Chore: Add context to temp user (#41284)

* Add context to temp user

* Remove xorm and InTransaction
This commit is contained in:
idafurjes 2021-11-04 11:17:07 +01:00 committed by GitHub
parent b82797d1b0
commit da5033f3fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 73 deletions

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -18,7 +19,7 @@ import (
func GetPendingOrgInvites(c *models.ReqContext) response.Response { func GetPendingOrgInvites(c *models.ReqContext) response.Response {
query := models.GetTempUsersQuery{OrgId: c.OrgId, Status: models.TmpUserInvitePending} query := models.GetTempUsersQuery{OrgId: c.OrgId, Status: models.TmpUserInvitePending}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
return response.Error(500, "Failed to get invites from db", err) return response.Error(500, "Failed to get invites from db", err)
} }
@ -62,7 +63,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
cmd.Role = inviteDto.Role cmd.Role = inviteDto.Role
cmd.RemoteAddr = c.Req.RemoteAddr cmd.RemoteAddr = c.Req.RemoteAddr
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to save invite to database", err) return response.Error(500, "Failed to save invite to database", err)
} }
@ -102,7 +103,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto *dtos.AddInviteForm) response.Response { func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto *dtos.AddInviteForm) response.Response {
// user exists, add org role // user exists, add org role
createOrgUserCmd := models.AddOrgUserCommand{OrgId: c.OrgId, UserId: user.Id, Role: inviteDto.Role} createOrgUserCmd := models.AddOrgUserCommand{OrgId: c.OrgId, UserId: user.Id, Role: inviteDto.Role}
if err := bus.Dispatch(&createOrgUserCmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &createOrgUserCmd); err != nil {
if errors.Is(err, models.ErrOrgUserAlreadyAdded) { if errors.Is(err, models.ErrOrgUserAlreadyAdded) {
return response.Error(412, fmt.Sprintf("User %s is already added to organization", inviteDto.LoginOrEmail), err) return response.Error(412, fmt.Sprintf("User %s is already added to organization", inviteDto.LoginOrEmail), err)
} }
@ -132,7 +133,7 @@ func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto
} }
func RevokeInvite(c *models.ReqContext) response.Response { func RevokeInvite(c *models.ReqContext) response.Response {
if ok, rsp := updateTempUserStatus(web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok { if ok, rsp := updateTempUserStatus(c.Req.Context(), web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok {
return rsp return rsp
} }
@ -144,7 +145,7 @@ func RevokeInvite(c *models.ReqContext) response.Response {
// If a (pending) invite is not found, 404 is returned. // If a (pending) invite is not found, 404 is returned.
func GetInviteInfoByCode(c *models.ReqContext) response.Response { func GetInviteInfoByCode(c *models.ReqContext) response.Response {
query := models.GetTempUserByCodeQuery{Code: web.Params(c.Req)[":code"]} query := models.GetTempUserByCodeQuery{Code: web.Params(c.Req)[":code"]}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) { if errors.Is(err, models.ErrTempUserNotFound) {
return response.Error(404, "Invite not found", nil) return response.Error(404, "Invite not found", nil)
} }
@ -167,7 +168,7 @@ func GetInviteInfoByCode(c *models.ReqContext) response.Response {
func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.CompleteInviteForm) response.Response { func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.CompleteInviteForm) response.Response {
query := models.GetTempUserByCodeQuery{Code: completeInvite.InviteCode} query := models.GetTempUserByCodeQuery{Code: completeInvite.InviteCode}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) { if errors.Is(err, models.ErrTempUserNotFound) {
return response.Error(404, "Invite not found", nil) return response.Error(404, "Invite not found", nil)
} }
@ -203,7 +204,7 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C
return response.Error(500, "failed to publish event", err) return response.Error(500, "failed to publish event", err)
} }
if ok, rsp := applyUserInvite(user, invite, true); !ok { if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, true); !ok {
return rsp return rsp
} }
@ -221,33 +222,33 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C
}) })
} }
func updateTempUserStatus(code string, status models.TempUserStatus) (bool, response.Response) { func updateTempUserStatus(ctx context.Context, code string, status models.TempUserStatus) (bool, response.Response) {
// update temp user status // update temp user status
updateTmpUserCmd := models.UpdateTempUserStatusCommand{Code: code, Status: status} updateTmpUserCmd := models.UpdateTempUserStatusCommand{Code: code, Status: status}
if err := bus.Dispatch(&updateTmpUserCmd); err != nil { if err := bus.DispatchCtx(ctx, &updateTmpUserCmd); err != nil {
return false, response.Error(500, "Failed to update invite status", err) return false, response.Error(500, "Failed to update invite status", err)
} }
return true, nil return true, nil
} }
func applyUserInvite(user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) { func applyUserInvite(ctx context.Context, user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) {
// add to org // add to org
addOrgUserCmd := models.AddOrgUserCommand{OrgId: invite.OrgId, UserId: user.Id, Role: invite.Role} addOrgUserCmd := models.AddOrgUserCommand{OrgId: invite.OrgId, UserId: user.Id, Role: invite.Role}
if err := bus.Dispatch(&addOrgUserCmd); err != nil { if err := bus.DispatchCtx(ctx, &addOrgUserCmd); err != nil {
if !errors.Is(err, models.ErrOrgUserAlreadyAdded) { if !errors.Is(err, models.ErrOrgUserAlreadyAdded) {
return false, response.Error(500, "Error while trying to create org user", err) return false, response.Error(500, "Error while trying to create org user", err)
} }
} }
// update temp user status // update temp user status
if ok, rsp := updateTempUserStatus(invite.Code, models.TmpUserCompleted); !ok { if ok, rsp := updateTempUserStatus(ctx, invite.Code, models.TmpUserCompleted); !ok {
return false, rsp return false, rsp
} }
if setActive { if setActive {
// set org to active // set org to active
if err := bus.Dispatch(&models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil { if err := bus.DispatchCtx(ctx, &models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil {
return false, response.Error(500, "Failed to set org as active", err) return false, response.Error(500, "Failed to set org as active", err)
} }
} }

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@ -44,7 +45,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response {
} }
cmd.RemoteAddr = c.Req.RemoteAddr cmd.RemoteAddr = c.Req.RemoteAddr
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to create signup", err) return response.Error(500, "Failed to create signup", err)
} }
@ -75,7 +76,7 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
// verify email // verify email
if setting.VerifyEmailEnabled { if setting.VerifyEmailEnabled {
if ok, rsp := verifyUserSignUpEmail(form.Email, form.Code); !ok { if ok, rsp := verifyUserSignUpEmail(c.Req.Context(), form.Email, form.Code); !ok {
return rsp return rsp
} }
createUserCmd.EmailVerified = true createUserCmd.EmailVerified = true
@ -99,19 +100,19 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
} }
// mark temp user as completed // mark temp user as completed
if ok, rsp := updateTempUserStatus(form.Code, models.TmpUserCompleted); !ok { if ok, rsp := updateTempUserStatus(c.Req.Context(), form.Code, models.TmpUserCompleted); !ok {
return rsp return rsp
} }
// check for pending invites // check for pending invites
invitesQuery := models.GetTempUsersQuery{Email: form.Email, Status: models.TmpUserInvitePending} invitesQuery := models.GetTempUsersQuery{Email: form.Email, Status: models.TmpUserInvitePending}
if err := bus.Dispatch(&invitesQuery); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &invitesQuery); err != nil {
return response.Error(500, "Failed to query database for invites", err) return response.Error(500, "Failed to query database for invites", err)
} }
apiResponse := util.DynMap{"message": "User sign up completed successfully", "code": "redirect-to-landing-page"} apiResponse := util.DynMap{"message": "User sign up completed successfully", "code": "redirect-to-landing-page"}
for _, invite := range invitesQuery.Result { for _, invite := range invitesQuery.Result {
if ok, rsp := applyUserInvite(user, invite, false); !ok { if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, false); !ok {
return rsp return rsp
} }
apiResponse["code"] = "redirect-to-select-org" apiResponse["code"] = "redirect-to-select-org"
@ -127,10 +128,10 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For
return response.JSON(200, apiResponse) return response.JSON(200, apiResponse)
} }
func verifyUserSignUpEmail(email string, code string) (bool, response.Response) { func verifyUserSignUpEmail(ctx context.Context, email string, code string) (bool, response.Response) {
query := models.GetTempUserByCodeQuery{Code: code} query := models.GetTempUserByCodeQuery{Code: code}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
if errors.Is(err, models.ErrTempUserNotFound) { if errors.Is(err, models.ErrTempUserNotFound) {
return false, response.Error(404, "Invalid email verification code", nil) return false, response.Error(404, "Invalid email verification code", nil)
} }

View File

@ -60,7 +60,7 @@ func (s *OSSService) SearchUser(c *models.ReqContext) (*models.SearchUsersQuery,
} }
query := &models.SearchUsersQuery{Query: searchQuery, Filters: filters, Page: page, Limit: perPage} query := &models.SearchUsersQuery{Query: searchQuery, Filters: filters, Page: page, Limit: perPage}
if err := s.bus.Dispatch(query); err != nil { if err := s.bus.DispatchCtx(c.Req.Context(), query); err != nil {
return nil, err return nil, err
} }

View File

@ -118,6 +118,7 @@ func newSQLStore(cfg *setting.Cfg, cacheService *localcache.CacheService, bus bu
ss.addOrgUsersQueryAndCommandHandlers() ss.addOrgUsersQueryAndCommandHandlers()
ss.addStarQueryAndCommandHandlers() ss.addStarQueryAndCommandHandlers()
ss.addAlertQueryAndCommandHandlers() ss.addAlertQueryAndCommandHandlers()
ss.addTempUserQueryAndCommandHandlers()
// if err := ss.Reset(); err != nil { // if err := ss.Reset(); err != nil {
// return nil, err // return nil, err

View File

@ -1,31 +1,32 @@
package sqlstore package sqlstore
import ( import (
"context"
"time" "time"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
) )
func init() { func (ss *SQLStore) addTempUserQueryAndCommandHandlers() {
bus.AddHandler("sql", CreateTempUser) bus.AddHandlerCtx("sql", ss.CreateTempUser)
bus.AddHandler("sql", GetTempUsersQuery) bus.AddHandlerCtx("sql", ss.GetTempUsersQuery)
bus.AddHandler("sql", UpdateTempUserStatus) bus.AddHandlerCtx("sql", ss.UpdateTempUserStatus)
bus.AddHandler("sql", GetTempUserByCode) bus.AddHandlerCtx("sql", ss.GetTempUserByCode)
bus.AddHandler("sql", UpdateTempUserWithEmailSent) bus.AddHandlerCtx("sql", ss.UpdateTempUserWithEmailSent)
bus.AddHandler("sql", ExpireOldUserInvites) bus.AddHandlerCtx("sql", ss.ExpireOldUserInvites)
} }
func UpdateTempUserStatus(cmd *models.UpdateTempUserStatusCommand) error { func (ss *SQLStore) UpdateTempUserStatus(ctx context.Context, cmd *models.UpdateTempUserStatusCommand) error {
return inTransaction(func(sess *DBSession) error { return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var rawSQL = "UPDATE temp_user SET status=? WHERE code=?" var rawSQL = "UPDATE temp_user SET status=? WHERE code=?"
_, err := sess.Exec(rawSQL, string(cmd.Status), cmd.Code) _, err := sess.Exec(rawSQL, string(cmd.Status), cmd.Code)
return err return err
}) })
} }
func CreateTempUser(cmd *models.CreateTempUserCommand) error { func (ss *SQLStore) CreateTempUser(ctx context.Context, cmd *models.CreateTempUserCommand) error {
return inTransaction(func(sess *DBSession) error { return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
// create user // create user
user := &models.TempUser{ user := &models.TempUser{
Email: cmd.Email, Email: cmd.Email,
@ -46,12 +47,13 @@ func CreateTempUser(cmd *models.CreateTempUserCommand) error {
} }
cmd.Result = user cmd.Result = user
return nil return nil
}) })
} }
func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand) error { func (ss *SQLStore) UpdateTempUserWithEmailSent(ctx context.Context, cmd *models.UpdateTempUserWithEmailSentCommand) error {
return inTransaction(func(sess *DBSession) error { return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
user := &models.TempUser{ user := &models.TempUser{
EmailSent: true, EmailSent: true,
EmailSentOn: time.Now(), EmailSentOn: time.Now(),
@ -63,8 +65,9 @@ func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand)
}) })
} }
func GetTempUsersQuery(query *models.GetTempUsersQuery) error { func (ss *SQLStore) GetTempUsersQuery(ctx context.Context, query *models.GetTempUsersQuery) error {
rawSQL := `SELECT return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
rawSQL := `SELECT
tu.id as id, tu.id as id,
tu.org_id as org_id, tu.org_id as org_id,
tu.email as email, tu.email as email,
@ -81,28 +84,30 @@ func GetTempUsersQuery(query *models.GetTempUsersQuery) error {
FROM ` + dialect.Quote("temp_user") + ` as tu FROM ` + dialect.Quote("temp_user") + ` as tu
LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id
WHERE tu.status=?` WHERE tu.status=?`
params := []interface{}{string(query.Status)} params := []interface{}{string(query.Status)}
if query.OrgId > 0 { if query.OrgId > 0 {
rawSQL += ` AND tu.org_id=?` rawSQL += ` AND tu.org_id=?`
params = append(params, query.OrgId) params = append(params, query.OrgId)
} }
if query.Email != "" { if query.Email != "" {
rawSQL += ` AND tu.email=?` rawSQL += ` AND tu.email=?`
params = append(params, query.Email) params = append(params, query.Email)
} }
rawSQL += " ORDER BY tu.created desc" rawSQL += " ORDER BY tu.created desc"
query.Result = make([]*models.TempUserDTO, 0) query.Result = make([]*models.TempUserDTO, 0)
sess := x.SQL(rawSQL, params...) sess := dbSess.SQL(rawSQL, params...)
err := sess.Find(&query.Result) err := sess.Find(&query.Result)
return err return err
})
} }
func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error { func (ss *SQLStore) GetTempUserByCode(ctx context.Context, query *models.GetTempUserByCodeQuery) error {
var rawSQL = `SELECT return ss.WithDbSession(ctx, func(dbSess *DBSession) error {
var rawSQL = `SELECT
tu.id as id, tu.id as id,
tu.org_id as org_id, tu.org_id as org_id,
tu.email as email, tu.email as email,
@ -120,22 +125,23 @@ func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error {
LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id
WHERE tu.code=?` WHERE tu.code=?`
var tempUser models.TempUserDTO var tempUser models.TempUserDTO
sess := x.SQL(rawSQL, query.Code) sess := dbSess.SQL(rawSQL, query.Code)
has, err := sess.Get(&tempUser) has, err := sess.Get(&tempUser)
if err != nil { if err != nil {
return err
} else if !has {
return models.ErrTempUserNotFound
}
query.Result = &tempUser
return err return err
} else if !has { })
return models.ErrTempUserNotFound
}
query.Result = &tempUser
return err
} }
func ExpireOldUserInvites(cmd *models.ExpireTempUsersCommand) error { func (ss *SQLStore) ExpireOldUserInvites(ctx context.Context, cmd *models.ExpireTempUsersCommand) error {
return inTransaction(func(sess *DBSession) error { return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var rawSQL = "UPDATE temp_user SET status = ?, updated = ? WHERE created <= ? AND status in (?, ?)" var rawSQL = "UPDATE temp_user SET status = ?, updated = ? WHERE created <= ? AND status in (?, ?)"
if result, err := sess.Exec(rawSQL, string(models.TmpUserExpired), time.Now().Unix(), cmd.OlderThan.Unix(), string(models.TmpUserSignUpStarted), string(models.TmpUserInvitePending)); err != nil { if result, err := sess.Exec(rawSQL, string(models.TmpUserExpired), time.Now().Unix(), cmd.OlderThan.Unix(), string(models.TmpUserSignUpStarted), string(models.TmpUserInvitePending)); err != nil {
return err return err

View File

@ -4,6 +4,7 @@
package sqlstore package sqlstore
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -13,6 +14,7 @@ import (
) )
func TestTempUserCommandsAndQueries(t *testing.T) { func TestTempUserCommandsAndQueries(t *testing.T) {
ss := InitTestDB(t)
cmd := models.CreateTempUserCommand{ cmd := models.CreateTempUserCommand{
OrgId: 2256, OrgId: 2256,
Name: "hello", Name: "hello",
@ -22,14 +24,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
} }
setup := func(t *testing.T) { setup := func(t *testing.T) {
InitTestDB(t) InitTestDB(t)
err := CreateTempUser(&cmd) err := ss.CreateTempUser(context.Background(), &cmd)
require.Nil(t, err) require.Nil(t, err)
} }
t.Run("Should be able to get temp users by org id", func(t *testing.T) { t.Run("Should be able to get temp users by org id", func(t *testing.T) {
setup(t) setup(t)
query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending} query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending}
err := GetTempUsersQuery(&query) err := ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -38,7 +40,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able to get temp users by email", func(t *testing.T) { t.Run("Should be able to get temp users by email", func(t *testing.T) {
setup(t) setup(t)
query := models.GetTempUsersQuery{Email: "e@as.co", Status: models.TmpUserInvitePending} query := models.GetTempUsersQuery{Email: "e@as.co", Status: models.TmpUserInvitePending}
err := GetTempUsersQuery(&query) err := ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, 1, len(query.Result)) require.Equal(t, 1, len(query.Result))
@ -47,7 +49,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able to get temp users by code", func(t *testing.T) { t.Run("Should be able to get temp users by code", func(t *testing.T) {
setup(t) setup(t)
query := models.GetTempUserByCodeQuery{Code: "asd"} query := models.GetTempUserByCodeQuery{Code: "asd"}
err := GetTempUserByCode(&query) err := ss.GetTempUserByCode(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "hello", query.Result.Name) require.Equal(t, "hello", query.Result.Name)
@ -56,18 +58,18 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
t.Run("Should be able update status", func(t *testing.T) { t.Run("Should be able update status", func(t *testing.T) {
setup(t) setup(t)
cmd2 := models.UpdateTempUserStatusCommand{Code: "asd", Status: models.TmpUserRevoked} cmd2 := models.UpdateTempUserStatusCommand{Code: "asd", Status: models.TmpUserRevoked}
err := UpdateTempUserStatus(&cmd2) err := ss.UpdateTempUserStatus(context.Background(), &cmd2)
require.Nil(t, err) require.Nil(t, err)
}) })
t.Run("Should be able update email sent and email sent on", func(t *testing.T) { t.Run("Should be able update email sent and email sent on", func(t *testing.T) {
setup(t) setup(t)
cmd2 := models.UpdateTempUserWithEmailSentCommand{Code: cmd.Result.Code} cmd2 := models.UpdateTempUserWithEmailSentCommand{Code: cmd.Result.Code}
err := UpdateTempUserWithEmailSent(&cmd2) err := ss.UpdateTempUserWithEmailSent(context.Background(), &cmd2)
require.Nil(t, err) require.Nil(t, err)
query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending} query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending}
err = GetTempUsersQuery(&query) err = ss.GetTempUsersQuery(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.True(t, query.Result[0].EmailSent) require.True(t, query.Result[0].EmailSent)
@ -78,14 +80,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) {
setup(t) setup(t)
createdAt := time.Unix(cmd.Result.Created, 0) createdAt := time.Unix(cmd.Result.Created, 0)
cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)} cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)}
err := ExpireOldUserInvites(&cmd2) err := ss.ExpireOldUserInvites(context.Background(), &cmd2)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(1), cmd2.NumExpired) require.Equal(t, int64(1), cmd2.NumExpired)
t.Run("Should do nothing when no temp users to expire", func(t *testing.T) { t.Run("Should do nothing when no temp users to expire", func(t *testing.T) {
createdAt := time.Unix(cmd.Result.Created, 0) createdAt := time.Unix(cmd.Result.Created, 0)
cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)} cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)}
err := ExpireOldUserInvites(&cmd2) err := ss.ExpireOldUserInvites(context.Background(), &cmd2)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), cmd2.NumExpired) require.Equal(t, int64(0), cmd2.NumExpired)
}) })