From 73e426b081e2d76be2b7ead89c39ec55097a29cd Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Thu, 28 Mar 2024 16:05:33 +0100 Subject: [PATCH] User: email verification completion (#85259) * TempUser: Include InvitedById in TempUserDTO * Extract email verfication completion flow to service --- pkg/api/user.go | 82 +--------- pkg/api/user_test.go | 4 +- pkg/services/temp_user/model.go | 1 + pkg/services/temp_user/tempuserimpl/store.go | 23 +-- pkg/services/temp_user/tempusertest/fake.go | 16 ++ pkg/services/user/model.go | 9 +- pkg/services/user/user.go | 3 +- pkg/services/user/userimpl/store.go | 9 +- pkg/services/user/userimpl/verifier.go | 79 +++++++++- pkg/services/user/userimpl/verifier_test.go | 153 ++++++++++++++++++- pkg/services/user/usertest/fake.go | 4 + 11 files changed, 275 insertions(+), 108 deletions(-) diff --git a/pkg/api/user.go b/pkg/api/user.go index 49a9572a79a..eadacc1ed6d 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -4,11 +4,9 @@ import ( "context" "errors" "net/http" - "net/mail" "net/url" "strconv" "strings" - "time" "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" @@ -17,7 +15,6 @@ import ( "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/team" - tempuser "github.com/grafana/grafana/pkg/services/temp_user" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" @@ -275,7 +272,7 @@ func (hs *HTTPServer) handleUpdateUser(ctx context.Context, cmd user.UpdateUserC } func (hs *HTTPServer) verifyEmailUpdate(ctx context.Context, email string, field user.UpdateEmailActionType, usr *user.User) response.Response { - if err := hs.userVerifier.VerifyEmail(ctx, user.VerifyEmailCommand{ + if err := hs.userVerifier.Start(ctx, user.StartVerifyEmailCommand{ User: *usr, Email: email, Action: field, @@ -295,37 +292,15 @@ func (hs *HTTPServer) verifyEmailUpdate(ctx context.Context, email string, field // Responses: // 302: okResponse func (hs *HTTPServer) UpdateUserEmail(c *contextmodel.ReqContext) response.Response { - var err error - - q := c.Req.URL.Query() - code, err := url.QueryUnescape(q.Get("code")) + code, err := url.QueryUnescape(c.Req.URL.Query().Get("code")) if err != nil || code == "" { return hs.RedirectResponseWithError(c, errors.New("bad request data")) } - tempUser, err := hs.validateEmailCode(c.Req.Context(), code) - if err != nil { + if err := hs.userVerifier.Complete(c.Req.Context(), user.CompleteEmailVerifyCommand{Code: code}); err != nil { return hs.RedirectResponseWithError(c, err) } - cmd, err := hs.updateCmdFromEmailVerification(c.Req.Context(), tempUser) - if err != nil { - return hs.RedirectResponseWithError(c, err) - } - - if err := hs.userService.Update(c.Req.Context(), cmd); err != nil { - if errors.Is(err, user.ErrCaseInsensitive) { - return hs.RedirectResponseWithError(c, errors.New("update would result in user login conflict")) - } - return hs.RedirectResponseWithError(c, errors.New("failed to update user")) - } - - // Mark temp user as completed - updateTmpUserCmd := tempuser.UpdateTempUserStatusCommand{Code: code, Status: tempuser.TmpUserEmailUpdateCompleted} - if err := hs.tempUserService.UpdateTempUserStatus(c.Req.Context(), &updateTmpUserCmd); err != nil { - return hs.RedirectResponseWithError(c, errors.New("failed to update verification status")) - } - return response.Redirect(hs.Cfg.AppSubURL + "/profile") } @@ -694,57 +669,6 @@ func getUserID(c *contextmodel.ReqContext) (int64, *response.NormalResponse) { return userID, nil } -func (hs *HTTPServer) updateCmdFromEmailVerification(ctx context.Context, tempUser *tempuser.TempUserDTO) (*user.UpdateUserCommand, error) { - userQuery := user.GetUserByLoginQuery{LoginOrEmail: tempUser.InvitedByLogin} - usr, err := hs.userService.GetByLogin(ctx, &userQuery) - if err != nil { - if errors.Is(err, user.ErrUserNotFound) { - return nil, user.ErrUserNotFound - } - return nil, errors.New("failed to get user") - } - - cmd := &user.UpdateUserCommand{UserID: usr.ID, Email: tempUser.Email} - - switch tempUser.Name { - case string(user.EmailUpdateAction): - // User updated the email field - if _, err := mail.ParseAddress(usr.Login); err == nil { - // If username was also an email, we update it to keep it in sync with the email field - cmd.Login = tempUser.Email - } - case string(user.LoginUpdateAction): - // User updated the username field with a new email - cmd.Login = tempUser.Email - default: - return nil, errors.New("trying to update email on unknown field") - } - return cmd, nil -} - -func (hs *HTTPServer) validateEmailCode(ctx context.Context, code string) (*tempuser.TempUserDTO, error) { - tempUserQuery := tempuser.GetTempUserByCodeQuery{Code: code} - tempUser, err := hs.tempUserService.GetTempUserByCode(ctx, &tempUserQuery) - if err != nil { - if errors.Is(err, tempuser.ErrTempUserNotFound) { - return nil, errors.New("invalid email verification code") - } - return nil, errors.New("failed to read temp user") - } - - if tempUser.Status != tempuser.TmpUserEmailUpdateStarted { - return nil, errors.New("invalid email verification code") - } - if !tempUser.EmailSent { - return nil, errors.New("verification email was not recorded as sent") - } - if tempUser.EmailSentOn.Add(hs.Cfg.VerificationEmailMaxLifetime).Before(time.Now()) { - return nil, errors.New("invalid email verification code") - } - - return tempUser, nil -} - // swagger:parameters searchUsers type SearchUsersParams struct { // Limit the maximum number of users to return per page diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index e10f15c5150..1f445a128ca 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -397,7 +397,7 @@ func setupUpdateEmailTests(t *testing.T, cfg *setting.Cfg) (*user.User, *HTTPSer require.NoError(t, err) nsMock := notifications.MockNotificationService() - verifier := userimpl.ProvideVerifier(userSvc, tempUserService, nsMock) + verifier := userimpl.ProvideVerifier(cfg, userSvc, tempUserService, nsMock) hs := &HTTPServer{ Cfg: cfg, @@ -620,7 +620,7 @@ func TestUser_UpdateEmail(t *testing.T) { hs.tempUserService = tempUserSvc hs.NotificationService = nsMock hs.SecretsService = fakes.NewFakeSecretsService() - hs.userVerifier = userimpl.ProvideVerifier(userSvc, tempUserSvc, nsMock) + hs.userVerifier = userimpl.ProvideVerifier(settings, userSvc, tempUserSvc, nsMock) // User is internal hs.authInfoService = &authinfotest.FakeService{ExpectedError: user.ErrUserNotFound} }) diff --git a/pkg/services/temp_user/model.go b/pkg/services/temp_user/model.go index b0949a79381..dbec94be914 100644 --- a/pkg/services/temp_user/model.go +++ b/pkg/services/temp_user/model.go @@ -96,6 +96,7 @@ type TempUserDTO struct { Name string `json:"name"` Email string `json:"email"` Role org.RoleType `json:"role"` + InvitedByID int64 `json:"-" xorm:"invited_by_id"` InvitedByLogin string `json:"invitedByLogin"` InvitedByEmail string `json:"invitedByEmail"` InvitedByName string `json:"invitedByName"` diff --git a/pkg/services/temp_user/tempuserimpl/store.go b/pkg/services/temp_user/tempuserimpl/store.go index 9fbef536b50..ca6b14dc2b8 100644 --- a/pkg/services/temp_user/tempuserimpl/store.go +++ b/pkg/services/temp_user/tempuserimpl/store.go @@ -129,18 +129,19 @@ func (ss *xormStore) GetTempUserByCode(ctx context.Context, query *tempuser.GetT tu.id as id, tu.org_id as org_id, tu.email as email, - tu.name as name, - tu.role as role, - tu.code as code, - tu.status as status, - tu.email_sent as email_sent, - tu.email_sent_on as email_sent_on, - tu.created as created, - u.login as invited_by_login, - u.name as invited_by_name, - u.email as invited_by_email + tu.name as name, + tu.role as role, + tu.code as code, + tu.status as status, + tu.email_sent as email_sent, + tu.email_sent_on as email_sent_on, + tu.created as created, + tu.invited_by_user_id as invited_by_id, + u.login as invited_by_login, + u.name as invited_by_name, + u.email as invited_by_email FROM ` + ss.db.GetDialect().Quote("temp_user") + ` as tu - LEFT OUTER JOIN ` + ss.db.GetDialect().Quote("user") + ` as u on u.id = tu.invited_by_user_id + LEFT OUTER JOIN ` + ss.db.GetDialect().Quote("user") + ` as u on u.id = tu.invited_by_user_id WHERE tu.code=?` var tempUser tempuser.TempUserDTO diff --git a/pkg/services/temp_user/tempusertest/fake.go b/pkg/services/temp_user/tempusertest/fake.go index 17d8c2c1330..0bd9b803a88 100644 --- a/pkg/services/temp_user/tempusertest/fake.go +++ b/pkg/services/temp_user/tempusertest/fake.go @@ -10,11 +10,27 @@ var _ tempuser.Service = (*FakeTempUserService)(nil) type FakeTempUserService struct { tempuser.Service + GetTempUserByCodeFN func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) + UpdateTempUserStatusFN func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error CreateTempUserFN func(ctx context.Context, cmd *tempuser.CreateTempUserCommand) (*tempuser.TempUser, error) ExpirePreviousVerificationsFN func(ctx context.Context, cmd *tempuser.ExpirePreviousVerificationsCommand) error UpdateTempUserWithEmailSentFN func(ctx context.Context, cmd *tempuser.UpdateTempUserWithEmailSentCommand) error } +func (f *FakeTempUserService) GetTempUserByCode(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + if f.GetTempUserByCodeFN != nil { + return f.GetTempUserByCodeFN(ctx, query) + } + return nil, nil +} + +func (f *FakeTempUserService) UpdateTempUserStatus(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error { + if f.UpdateTempUserStatusFN != nil { + return f.UpdateTempUserStatusFN(ctx, cmd) + } + return nil +} + func (f *FakeTempUserService) CreateTempUser(ctx context.Context, cmd *tempuser.CreateTempUserCommand) (*tempuser.TempUser, error) { if f.CreateTempUserFN != nil { return f.CreateTempUserFN(ctx, cmd) diff --git a/pkg/services/user/model.go b/pkg/services/user/model.go index 1d7e8db976e..5a51644c9fb 100644 --- a/pkg/services/user/model.go +++ b/pkg/services/user/model.go @@ -82,7 +82,8 @@ type UpdateUserCommand struct { Login string `json:"login"` Theme string `json:"theme"` - UserID int64 `json:"-"` + UserID int64 `json:"-"` + EmailVerified *bool `json:"-"` } type ChangeUserPasswordCommand struct { @@ -220,12 +221,16 @@ type GetUserByIDQuery struct { ID int64 } -type VerifyEmailCommand struct { +type StartVerifyEmailCommand struct { User User Email string Action UpdateEmailActionType } +type CompleteEmailVerifyCommand struct { + Code string +} + type ErrCaseInsensitiveLoginConflict struct { Users []User } diff --git a/pkg/services/user/user.go b/pkg/services/user/user.go index 1466fed7c66..1282546d676 100644 --- a/pkg/services/user/user.go +++ b/pkg/services/user/user.go @@ -31,5 +31,6 @@ type Service interface { } type Verifier interface { - VerifyEmail(ctx context.Context, cmd VerifyEmailCommand) error + Start(ctx context.Context, cmd StartVerifyEmailCommand) error + Complete(ctx context.Context, cmd CompleteEmailVerifyCommand) error } diff --git a/pkg/services/user/userimpl/store.go b/pkg/services/user/userimpl/store.go index c91195c203b..9a8b337542f 100644 --- a/pkg/services/user/userimpl/store.go +++ b/pkg/services/user/userimpl/store.go @@ -315,7 +315,14 @@ func (ss *sqlStore) Update(ctx context.Context, cmd *user.UpdateUserCommand) err Updated: time.Now(), } - if _, err := sess.ID(cmd.UserID).Where(ss.notServiceAccountFilter()).Update(&user); err != nil { + q := sess.ID(cmd.UserID).Where(ss.notServiceAccountFilter()) + + if cmd.EmailVerified != nil { + q.UseBool("email_verified") + user.EmailVerified = *cmd.EmailVerified + } + + if _, err := q.Update(&user); err != nil { return err } diff --git a/pkg/services/user/userimpl/verifier.go b/pkg/services/user/userimpl/verifier.go index 719a530937a..d6ab7284a2f 100644 --- a/pkg/services/user/userimpl/verifier.go +++ b/pkg/services/user/userimpl/verifier.go @@ -4,26 +4,36 @@ import ( "context" "errors" "fmt" + "net/mail" + "time" "github.com/grafana/grafana/pkg/services/notifications" tempuser "github.com/grafana/grafana/pkg/services/temp_user" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" + "github.com/grafana/grafana/pkg/util/errutil" +) + +var ( + errInvalidCode = errutil.BadRequest("user.code.invalid", errutil.WithPublicMessage("Invalid verification code")) + errExpiredCode = errutil.BadRequest("user.code.expired", errutil.WithPublicMessage("Verification code has expired")) ) var _ user.Verifier = (*Verifier)(nil) -func ProvideVerifier(us user.Service, ts tempuser.Service, ns notifications.Service) *Verifier { - return &Verifier{us, ts, ns} +func ProvideVerifier(cfg *setting.Cfg, us user.Service, ts tempuser.Service, ns notifications.Service) *Verifier { + return &Verifier{cfg, us, ts, ns} } type Verifier struct { - us user.Service - ts tempuser.Service - ns notifications.Service + cfg *setting.Cfg + us user.Service + ts tempuser.Service + ns notifications.Service } -func (s *Verifier) VerifyEmail(ctx context.Context, cmd user.VerifyEmailCommand) error { +func (s *Verifier) Start(ctx context.Context, cmd user.StartVerifyEmailCommand) error { usr, err := s.us.GetByLogin(ctx, &user.GetUserByLoginQuery{ LoginOrEmail: cmd.Email, }) @@ -80,3 +90,60 @@ func (s *Verifier) VerifyEmail(ctx context.Context, cmd user.VerifyEmailCommand) return nil } + +func (s *Verifier) Complete(ctx context.Context, cmd user.CompleteEmailVerifyCommand) error { + tmpUsr, err := s.ts.GetTempUserByCode(ctx, &tempuser.GetTempUserByCodeQuery{Code: cmd.Code}) + if err != nil { + return errInvalidCode.Errorf("failed to verify code: %w", err) + } + + if tmpUsr.Status != tempuser.TmpUserEmailUpdateStarted { + return errInvalidCode.Errorf("wrong status for verification code: %s", tmpUsr.Status) + } + + if !tmpUsr.EmailSent { + return errInvalidCode.Errorf("email was not marked as sent") + } + + if tmpUsr.EmailSentOn.Add(s.cfg.VerificationEmailMaxLifetime).Before(time.Now()) { + return errExpiredCode.Errorf("verification code has expired") + } + + usr, err := s.us.GetByID(ctx, &user.GetUserByIDQuery{ID: tmpUsr.InvitedByID}) + if err != nil { + return err + } + + verified := true + update := &user.UpdateUserCommand{ + Email: tmpUsr.Email, + UserID: tmpUsr.InvitedByID, + EmailVerified: &verified, + } + switch tmpUsr.Name { + case string(user.EmailUpdateAction): + // User updated the email field + if _, err := mail.ParseAddress(usr.Login); err == nil { + // If username was also an email, we update it to keep it in sync with the email field + update.Login = tmpUsr.Email + } + case string(user.LoginUpdateAction): + // User updated the username field with a new email + update.Login = tmpUsr.Email + default: + return errors.New("trying to update email on unknown field") + } + + if err := s.us.Update(ctx, update); err != nil { + return err + } + + if err := s.ts.UpdateTempUserStatus( + ctx, + &tempuser.UpdateTempUserStatusCommand{Code: cmd.Code, Status: tempuser.TmpUserEmailUpdateCompleted}, + ); err != nil { + return err + } + + return nil +} diff --git a/pkg/services/user/userimpl/verifier_test.go b/pkg/services/user/userimpl/verifier_test.go index c57dc2fabf8..0e68088c56b 100644 --- a/pkg/services/user/userimpl/verifier_test.go +++ b/pkg/services/user/userimpl/verifier_test.go @@ -3,6 +3,7 @@ package userimpl import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" @@ -11,9 +12,10 @@ import ( "github.com/grafana/grafana/pkg/services/temp_user/tempusertest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user/usertest" + "github.com/grafana/grafana/pkg/setting" ) -func TestVerifier_VerifyEmail(t *testing.T) { +func TestVerifier_Start(t *testing.T) { ts := &tempusertest.FakeTempUserService{} us := &usertest.FakeUserService{} ns := notifications.MockNotificationService() @@ -24,10 +26,10 @@ func TestVerifier_VerifyEmail(t *testing.T) { updateCalled bool } - verifier := ProvideVerifier(us, ts, ns) + verifier := ProvideVerifier(setting.NewCfg(), us, ts, ns) t.Run("should error if email already exist for other user", func(t *testing.T) { us.ExpectedUser = &user.User{ID: 1} - err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{ + err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{ User: user.User{ID: 2}, Email: "some@email.com", Action: user.EmailUpdateAction, @@ -59,13 +61,13 @@ func TestVerifier_VerifyEmail(t *testing.T) { c.updateCalled = true return nil } - err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{ + err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{ User: user.User{ID: 2}, Email: "some@email.com", Action: user.EmailUpdateAction, }) - assert.ErrorIs(t, err, nil) + assert.NoError(t, err) assert.True(t, c.expireCalled) assert.True(t, c.createCalled) assert.True(t, c.updateCalled) @@ -94,7 +96,7 @@ func TestVerifier_VerifyEmail(t *testing.T) { c.updateCalled = true return nil } - err := verifier.VerifyEmail(context.Background(), user.VerifyEmailCommand{ + err := verifier.Start(context.Background(), user.StartVerifyEmailCommand{ User: user.User{ID: 2}, Email: "some@email.com", Action: user.EmailUpdateAction, @@ -106,3 +108,142 @@ func TestVerifier_VerifyEmail(t *testing.T) { assert.True(t, c.updateCalled) }) } + +func TestVerifier_Complete(t *testing.T) { + ts := &tempusertest.FakeTempUserService{} + us := &usertest.FakeUserService{} + ns := notifications.MockNotificationService() + + type calls struct { + updateCalled bool + updateStatusCalled bool + } + + cfg := setting.NewCfg() + cfg.VerificationEmailMaxLifetime = 1 * time.Hour + verifier := ProvideVerifier(cfg, us, ts, ns) + t.Run("should return error for invalid code", func(t *testing.T) { + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return nil, tempuser.ErrTempUserNotFound + } + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.ErrorIs(t, err, errInvalidCode) + }) + + t.Run("should return error when verification has wrong status", func(t *testing.T) { + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateCompleted, + }, nil + } + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.ErrorIs(t, err, errInvalidCode) + }) + + t.Run("should return error when verification email was never sent", func(t *testing.T) { + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateStarted, + EmailSent: false, + }, nil + } + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.ErrorIs(t, err, errInvalidCode) + }) + + t.Run("should return error when verification code has expired", func(t *testing.T) { + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateStarted, + EmailSent: true, + EmailSentOn: time.Now().Add(-10 * time.Hour), + }, nil + } + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.ErrorIs(t, err, errExpiredCode) + }) + + t.Run("should return error user connect to code don't exists", func(t *testing.T) { + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateStarted, + EmailSent: true, + EmailSentOn: time.Now(), + }, nil + } + us.ExpectedError = user.ErrUserNotFound + + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.ErrorIs(t, err, user.ErrUserNotFound) + }) + + t.Run("should update user email on valid code", func(t *testing.T) { + var c calls + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateStarted, + Name: string(user.EmailUpdateAction), + InvitedByID: 1, + Email: "updated@email.com", + EmailSent: true, + EmailSentOn: time.Now(), + }, nil + } + + ts.UpdateTempUserStatusFN = func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error { + c.updateStatusCalled = true + return nil + } + + us.ExpectedUser = &user.User{Email: "initial@email.com"} + us.ExpectedError = nil + us.UpdateFn = func(ctx context.Context, cmd *user.UpdateUserCommand) error { + c.updateCalled = true + assert.True(t, *cmd.EmailVerified) + assert.Equal(t, int64(1), cmd.UserID) + assert.Equal(t, "", cmd.Login) + assert.Equal(t, "updated@email.com", cmd.Email) + return nil + } + + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.NoError(t, err) + assert.True(t, c.updateCalled) + assert.True(t, c.updateStatusCalled) + }) + + t.Run("should update user email and login if login is an email on valid code", func(t *testing.T) { + var c calls + ts.GetTempUserByCodeFN = func(ctx context.Context, query *tempuser.GetTempUserByCodeQuery) (*tempuser.TempUserDTO, error) { + return &tempuser.TempUserDTO{ + Status: tempuser.TmpUserEmailUpdateStarted, + Name: string(user.EmailUpdateAction), + InvitedByID: 1, + Email: "updated@email.com", + EmailSent: true, + EmailSentOn: time.Now(), + }, nil + } + + ts.UpdateTempUserStatusFN = func(ctx context.Context, cmd *tempuser.UpdateTempUserStatusCommand) error { + c.updateStatusCalled = true + return nil + } + + us.ExpectedUser = &user.User{Email: "initial@email.com", Login: "other@email.com"} + us.ExpectedError = nil + us.UpdateFn = func(ctx context.Context, cmd *user.UpdateUserCommand) error { + c.updateCalled = true + assert.True(t, *cmd.EmailVerified) + assert.Equal(t, int64(1), cmd.UserID) + assert.Equal(t, "updated@email.com", cmd.Email) + assert.Equal(t, "updated@email.com", cmd.Login) + return nil + } + + err := verifier.Complete(context.Background(), user.CompleteEmailVerifyCommand{Code: "some-code"}) + assert.NoError(t, err) + assert.True(t, c.updateCalled) + assert.True(t, c.updateStatusCalled) + }) +} diff --git a/pkg/services/user/usertest/fake.go b/pkg/services/user/usertest/fake.go index 1a696a3a86c..b1e0aa81f3d 100644 --- a/pkg/services/user/usertest/fake.go +++ b/pkg/services/user/usertest/fake.go @@ -16,6 +16,7 @@ type FakeUserService struct { ExpectedUserProfileDTOs []*user.UserProfileDTO ExpectedUsageStats map[string]any + UpdateFn func(ctx context.Context, cmd *user.UpdateUserCommand) error GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error) DisableFn func(ctx context.Context, cmd *user.DisableUserCommand) error @@ -61,6 +62,9 @@ func (f *FakeUserService) GetByEmail(ctx context.Context, query *user.GetUserByE } func (f *FakeUserService) Update(ctx context.Context, cmd *user.UpdateUserCommand) error { + if f.UpdateFn != nil { + return f.UpdateFn(ctx, cmd) + } return f.ExpectedError }