diff --git a/pkg/api/org_invite.go b/pkg/api/org_invite.go index 0c718dee7c9..51b567d6464 100644 --- a/pkg/api/org_invite.go +++ b/pkg/api/org_invite.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "fmt" @@ -18,7 +19,7 @@ import ( func GetPendingOrgInvites(c *models.ReqContext) response.Response { query := models.GetTempUsersQuery{OrgId: c.OrgId, Status: models.TmpUserInvitePending} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil { return response.Error(500, "Failed to get invites from db", err) } @@ -62,7 +63,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R cmd.Role = inviteDto.Role cmd.RemoteAddr = c.Req.RemoteAddr - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to save invite to database", err) } @@ -102,7 +103,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto *dtos.AddInviteForm) response.Response { // user exists, add org role createOrgUserCmd := models.AddOrgUserCommand{OrgId: c.OrgId, UserId: user.Id, Role: inviteDto.Role} - if err := bus.Dispatch(&createOrgUserCmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &createOrgUserCmd); err != nil { if errors.Is(err, models.ErrOrgUserAlreadyAdded) { return response.Error(412, fmt.Sprintf("User %s is already added to organization", inviteDto.LoginOrEmail), err) } @@ -132,7 +133,7 @@ func inviteExistingUserToOrg(c *models.ReqContext, user *models.User, inviteDto } func RevokeInvite(c *models.ReqContext) response.Response { - if ok, rsp := updateTempUserStatus(web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok { + if ok, rsp := updateTempUserStatus(c.Req.Context(), web.Params(c.Req)[":code"], models.TmpUserRevoked); !ok { return rsp } @@ -144,7 +145,7 @@ func RevokeInvite(c *models.ReqContext) response.Response { // If a (pending) invite is not found, 404 is returned. func GetInviteInfoByCode(c *models.ReqContext) response.Response { query := models.GetTempUserByCodeQuery{Code: web.Params(c.Req)[":code"]} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrTempUserNotFound) { return response.Error(404, "Invite not found", nil) } @@ -167,7 +168,7 @@ func GetInviteInfoByCode(c *models.ReqContext) response.Response { func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.CompleteInviteForm) response.Response { query := models.GetTempUserByCodeQuery{Code: completeInvite.InviteCode} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrTempUserNotFound) { return response.Error(404, "Invite not found", nil) } @@ -203,7 +204,7 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C return response.Error(500, "failed to publish event", err) } - if ok, rsp := applyUserInvite(user, invite, true); !ok { + if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, true); !ok { return rsp } @@ -221,33 +222,33 @@ func (hs *HTTPServer) CompleteInvite(c *models.ReqContext, completeInvite dtos.C }) } -func updateTempUserStatus(code string, status models.TempUserStatus) (bool, response.Response) { +func updateTempUserStatus(ctx context.Context, code string, status models.TempUserStatus) (bool, response.Response) { // update temp user status updateTmpUserCmd := models.UpdateTempUserStatusCommand{Code: code, Status: status} - if err := bus.Dispatch(&updateTmpUserCmd); err != nil { + if err := bus.DispatchCtx(ctx, &updateTmpUserCmd); err != nil { return false, response.Error(500, "Failed to update invite status", err) } return true, nil } -func applyUserInvite(user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) { +func applyUserInvite(ctx context.Context, user *models.User, invite *models.TempUserDTO, setActive bool) (bool, response.Response) { // add to org addOrgUserCmd := models.AddOrgUserCommand{OrgId: invite.OrgId, UserId: user.Id, Role: invite.Role} - if err := bus.Dispatch(&addOrgUserCmd); err != nil { + if err := bus.DispatchCtx(ctx, &addOrgUserCmd); err != nil { if !errors.Is(err, models.ErrOrgUserAlreadyAdded) { return false, response.Error(500, "Error while trying to create org user", err) } } // update temp user status - if ok, rsp := updateTempUserStatus(invite.Code, models.TmpUserCompleted); !ok { + if ok, rsp := updateTempUserStatus(ctx, invite.Code, models.TmpUserCompleted); !ok { return false, rsp } if setActive { // set org to active - if err := bus.Dispatch(&models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil { + if err := bus.DispatchCtx(ctx, &models.SetUsingOrgCommand{OrgId: invite.OrgId, UserId: user.Id}); err != nil { return false, response.Error(500, "Failed to set org as active", err) } } diff --git a/pkg/api/signup.go b/pkg/api/signup.go index 4b4a237db72..6b9331874d0 100644 --- a/pkg/api/signup.go +++ b/pkg/api/signup.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "github.com/grafana/grafana/pkg/api/dtos" @@ -44,7 +45,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response { } cmd.RemoteAddr = c.Req.RemoteAddr - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to create signup", err) } @@ -75,7 +76,7 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For // verify email if setting.VerifyEmailEnabled { - if ok, rsp := verifyUserSignUpEmail(form.Email, form.Code); !ok { + if ok, rsp := verifyUserSignUpEmail(c.Req.Context(), form.Email, form.Code); !ok { return rsp } createUserCmd.EmailVerified = true @@ -99,19 +100,19 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For } // mark temp user as completed - if ok, rsp := updateTempUserStatus(form.Code, models.TmpUserCompleted); !ok { + if ok, rsp := updateTempUserStatus(c.Req.Context(), form.Code, models.TmpUserCompleted); !ok { return rsp } // check for pending invites invitesQuery := models.GetTempUsersQuery{Email: form.Email, Status: models.TmpUserInvitePending} - if err := bus.Dispatch(&invitesQuery); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &invitesQuery); err != nil { return response.Error(500, "Failed to query database for invites", err) } apiResponse := util.DynMap{"message": "User sign up completed successfully", "code": "redirect-to-landing-page"} for _, invite := range invitesQuery.Result { - if ok, rsp := applyUserInvite(user, invite, false); !ok { + if ok, rsp := applyUserInvite(c.Req.Context(), user, invite, false); !ok { return rsp } apiResponse["code"] = "redirect-to-select-org" @@ -127,10 +128,10 @@ func (hs *HTTPServer) SignUpStep2(c *models.ReqContext, form dtos.SignUpStep2For return response.JSON(200, apiResponse) } -func verifyUserSignUpEmail(email string, code string) (bool, response.Response) { +func verifyUserSignUpEmail(ctx context.Context, email string, code string) (bool, response.Response) { query := models.GetTempUserByCodeQuery{Code: code} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(ctx, &query); err != nil { if errors.Is(err, models.ErrTempUserNotFound) { return false, response.Error(404, "Invalid email verification code", nil) } diff --git a/pkg/services/searchusers/searchusers.go b/pkg/services/searchusers/searchusers.go index 1d72508e97b..c7070d26cee 100644 --- a/pkg/services/searchusers/searchusers.go +++ b/pkg/services/searchusers/searchusers.go @@ -60,7 +60,7 @@ func (s *OSSService) SearchUser(c *models.ReqContext) (*models.SearchUsersQuery, } query := &models.SearchUsersQuery{Query: searchQuery, Filters: filters, Page: page, Limit: perPage} - if err := s.bus.Dispatch(query); err != nil { + if err := s.bus.DispatchCtx(c.Req.Context(), query); err != nil { return nil, err } diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index 8aa95646069..23900e94f6f 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -118,6 +118,7 @@ func newSQLStore(cfg *setting.Cfg, cacheService *localcache.CacheService, bus bu ss.addOrgUsersQueryAndCommandHandlers() ss.addStarQueryAndCommandHandlers() ss.addAlertQueryAndCommandHandlers() + ss.addTempUserQueryAndCommandHandlers() // if err := ss.Reset(); err != nil { // return nil, err diff --git a/pkg/services/sqlstore/temp_user.go b/pkg/services/sqlstore/temp_user.go index 3fc884e516a..d78cd758170 100644 --- a/pkg/services/sqlstore/temp_user.go +++ b/pkg/services/sqlstore/temp_user.go @@ -1,31 +1,32 @@ package sqlstore import ( + "context" "time" "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" ) -func init() { - bus.AddHandler("sql", CreateTempUser) - bus.AddHandler("sql", GetTempUsersQuery) - bus.AddHandler("sql", UpdateTempUserStatus) - bus.AddHandler("sql", GetTempUserByCode) - bus.AddHandler("sql", UpdateTempUserWithEmailSent) - bus.AddHandler("sql", ExpireOldUserInvites) +func (ss *SQLStore) addTempUserQueryAndCommandHandlers() { + bus.AddHandlerCtx("sql", ss.CreateTempUser) + bus.AddHandlerCtx("sql", ss.GetTempUsersQuery) + bus.AddHandlerCtx("sql", ss.UpdateTempUserStatus) + bus.AddHandlerCtx("sql", ss.GetTempUserByCode) + bus.AddHandlerCtx("sql", ss.UpdateTempUserWithEmailSent) + bus.AddHandlerCtx("sql", ss.ExpireOldUserInvites) } -func UpdateTempUserStatus(cmd *models.UpdateTempUserStatusCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) UpdateTempUserStatus(ctx context.Context, cmd *models.UpdateTempUserStatusCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { var rawSQL = "UPDATE temp_user SET status=? WHERE code=?" _, err := sess.Exec(rawSQL, string(cmd.Status), cmd.Code) return err }) } -func CreateTempUser(cmd *models.CreateTempUserCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) CreateTempUser(ctx context.Context, cmd *models.CreateTempUserCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { // create user user := &models.TempUser{ Email: cmd.Email, @@ -46,12 +47,13 @@ func CreateTempUser(cmd *models.CreateTempUserCommand) error { } cmd.Result = user + return nil }) } -func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) UpdateTempUserWithEmailSent(ctx context.Context, cmd *models.UpdateTempUserWithEmailSentCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { user := &models.TempUser{ EmailSent: true, EmailSentOn: time.Now(), @@ -63,8 +65,9 @@ func UpdateTempUserWithEmailSent(cmd *models.UpdateTempUserWithEmailSentCommand) }) } -func GetTempUsersQuery(query *models.GetTempUsersQuery) error { - rawSQL := `SELECT +func (ss *SQLStore) GetTempUsersQuery(ctx context.Context, query *models.GetTempUsersQuery) error { + return ss.WithDbSession(ctx, func(dbSess *DBSession) error { + rawSQL := `SELECT tu.id as id, tu.org_id as org_id, tu.email as email, @@ -81,28 +84,30 @@ func GetTempUsersQuery(query *models.GetTempUsersQuery) error { FROM ` + dialect.Quote("temp_user") + ` as tu LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id WHERE tu.status=?` - params := []interface{}{string(query.Status)} + params := []interface{}{string(query.Status)} - if query.OrgId > 0 { - rawSQL += ` AND tu.org_id=?` - params = append(params, query.OrgId) - } + if query.OrgId > 0 { + rawSQL += ` AND tu.org_id=?` + params = append(params, query.OrgId) + } - if query.Email != "" { - rawSQL += ` AND tu.email=?` - params = append(params, query.Email) - } + if query.Email != "" { + rawSQL += ` AND tu.email=?` + params = append(params, query.Email) + } - rawSQL += " ORDER BY tu.created desc" + rawSQL += " ORDER BY tu.created desc" - query.Result = make([]*models.TempUserDTO, 0) - sess := x.SQL(rawSQL, params...) - err := sess.Find(&query.Result) - return err + query.Result = make([]*models.TempUserDTO, 0) + sess := dbSess.SQL(rawSQL, params...) + err := sess.Find(&query.Result) + return err + }) } -func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error { - var rawSQL = `SELECT +func (ss *SQLStore) GetTempUserByCode(ctx context.Context, query *models.GetTempUserByCodeQuery) error { + return ss.WithDbSession(ctx, func(dbSess *DBSession) error { + var rawSQL = `SELECT tu.id as id, tu.org_id as org_id, tu.email as email, @@ -120,22 +125,23 @@ func GetTempUserByCode(query *models.GetTempUserByCodeQuery) error { LEFT OUTER JOIN ` + dialect.Quote("user") + ` as u on u.id = tu.invited_by_user_id WHERE tu.code=?` - var tempUser models.TempUserDTO - sess := x.SQL(rawSQL, query.Code) - has, err := sess.Get(&tempUser) + var tempUser models.TempUserDTO + sess := dbSess.SQL(rawSQL, query.Code) + has, err := sess.Get(&tempUser) - if err != nil { + if err != nil { + return err + } else if !has { + return models.ErrTempUserNotFound + } + + query.Result = &tempUser return err - } else if !has { - return models.ErrTempUserNotFound - } - - query.Result = &tempUser - return err + }) } -func ExpireOldUserInvites(cmd *models.ExpireTempUsersCommand) error { - return inTransaction(func(sess *DBSession) error { +func (ss *SQLStore) ExpireOldUserInvites(ctx context.Context, cmd *models.ExpireTempUsersCommand) error { + return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error { var rawSQL = "UPDATE temp_user SET status = ?, updated = ? WHERE created <= ? AND status in (?, ?)" if result, err := sess.Exec(rawSQL, string(models.TmpUserExpired), time.Now().Unix(), cmd.OlderThan.Unix(), string(models.TmpUserSignUpStarted), string(models.TmpUserInvitePending)); err != nil { return err diff --git a/pkg/services/sqlstore/temp_user_test.go b/pkg/services/sqlstore/temp_user_test.go index 40b11524991..28f88e084aa 100644 --- a/pkg/services/sqlstore/temp_user_test.go +++ b/pkg/services/sqlstore/temp_user_test.go @@ -4,6 +4,7 @@ package sqlstore import ( + "context" "testing" "time" @@ -13,6 +14,7 @@ import ( ) func TestTempUserCommandsAndQueries(t *testing.T) { + ss := InitTestDB(t) cmd := models.CreateTempUserCommand{ OrgId: 2256, Name: "hello", @@ -22,14 +24,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) { } setup := func(t *testing.T) { InitTestDB(t) - err := CreateTempUser(&cmd) + err := ss.CreateTempUser(context.Background(), &cmd) require.Nil(t, err) } t.Run("Should be able to get temp users by org id", func(t *testing.T) { setup(t) query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending} - err := GetTempUsersQuery(&query) + err := ss.GetTempUsersQuery(context.Background(), &query) require.Nil(t, err) require.Equal(t, 1, len(query.Result)) @@ -38,7 +40,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) { t.Run("Should be able to get temp users by email", func(t *testing.T) { setup(t) query := models.GetTempUsersQuery{Email: "e@as.co", Status: models.TmpUserInvitePending} - err := GetTempUsersQuery(&query) + err := ss.GetTempUsersQuery(context.Background(), &query) require.Nil(t, err) require.Equal(t, 1, len(query.Result)) @@ -47,7 +49,7 @@ func TestTempUserCommandsAndQueries(t *testing.T) { t.Run("Should be able to get temp users by code", func(t *testing.T) { setup(t) query := models.GetTempUserByCodeQuery{Code: "asd"} - err := GetTempUserByCode(&query) + err := ss.GetTempUserByCode(context.Background(), &query) require.Nil(t, err) require.Equal(t, "hello", query.Result.Name) @@ -56,18 +58,18 @@ func TestTempUserCommandsAndQueries(t *testing.T) { t.Run("Should be able update status", func(t *testing.T) { setup(t) cmd2 := models.UpdateTempUserStatusCommand{Code: "asd", Status: models.TmpUserRevoked} - err := UpdateTempUserStatus(&cmd2) + err := ss.UpdateTempUserStatus(context.Background(), &cmd2) require.Nil(t, err) }) t.Run("Should be able update email sent and email sent on", func(t *testing.T) { setup(t) cmd2 := models.UpdateTempUserWithEmailSentCommand{Code: cmd.Result.Code} - err := UpdateTempUserWithEmailSent(&cmd2) + err := ss.UpdateTempUserWithEmailSent(context.Background(), &cmd2) require.Nil(t, err) query := models.GetTempUsersQuery{OrgId: 2256, Status: models.TmpUserInvitePending} - err = GetTempUsersQuery(&query) + err = ss.GetTempUsersQuery(context.Background(), &query) require.Nil(t, err) require.True(t, query.Result[0].EmailSent) @@ -78,14 +80,14 @@ func TestTempUserCommandsAndQueries(t *testing.T) { setup(t) createdAt := time.Unix(cmd.Result.Created, 0) cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)} - err := ExpireOldUserInvites(&cmd2) + err := ss.ExpireOldUserInvites(context.Background(), &cmd2) require.Nil(t, err) require.Equal(t, int64(1), cmd2.NumExpired) t.Run("Should do nothing when no temp users to expire", func(t *testing.T) { createdAt := time.Unix(cmd.Result.Created, 0) cmd2 := models.ExpireTempUsersCommand{OlderThan: createdAt.Add(1 * time.Second)} - err := ExpireOldUserInvites(&cmd2) + err := ss.ExpireOldUserInvites(context.Background(), &cmd2) require.Nil(t, err) require.Equal(t, int64(0), cmd2.NumExpired) })