diff --git a/pkg/api/user.go b/pkg/api/user.go index 9bbe536e6af..9f6341b23a8 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "github.com/grafana/grafana/pkg/api/dtos" @@ -13,18 +14,18 @@ import ( // GET /api/user (current authenticated user) func GetSignedInUser(c *models.ReqContext) response.Response { - return getUserUserProfile(c.UserId) + return getUserUserProfile(c.Req.Context(), c.UserId) } // GET /api/users/:id func GetUserByID(c *models.ReqContext) response.Response { - return getUserUserProfile(c.ParamsInt64(":id")) + return getUserUserProfile(c.Req.Context(), c.ParamsInt64(":id")) } -func getUserUserProfile(userID int64) response.Response { +func getUserUserProfile(ctx context.Context, userID int64) response.Response { query := models.GetUserProfileQuery{UserId: userID} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(ctx, &query); err != nil { if errors.Is(err, models.ErrUserNotFound) { return response.Error(404, models.ErrUserNotFound.Error(), nil) } @@ -33,7 +34,7 @@ func getUserUserProfile(userID int64) response.Response { getAuthQuery := models.GetAuthInfoQuery{UserId: userID} query.Result.AuthLabels = []string{} - if err := bus.Dispatch(&getAuthQuery); err == nil { + if err := bus.DispatchCtx(ctx, &getAuthQuery); err == nil { authLabel := GetAuthProviderLabel(getAuthQuery.Result.AuthModule) query.Result.AuthLabels = append(query.Result.AuthLabels, authLabel) query.Result.IsExternal = true @@ -47,7 +48,7 @@ func getUserUserProfile(userID int64) response.Response { // GET /api/users/lookup func GetUserByLoginOrEmail(c *models.ReqContext) response.Response { query := models.GetUserByLoginQuery{LoginOrEmail: c.Query("loginOrEmail")} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrUserNotFound) { return response.Error(404, models.ErrUserNotFound.Error(), nil) } @@ -79,13 +80,13 @@ func UpdateSignedInUser(c *models.ReqContext, cmd models.UpdateUserCommand) resp } } cmd.UserId = c.UserId - return handleUpdateUser(cmd) + return handleUpdateUser(c.Req.Context(), cmd) } // POST /api/users/:id func UpdateUser(c *models.ReqContext, cmd models.UpdateUserCommand) response.Response { cmd.UserId = c.ParamsInt64(":id") - return handleUpdateUser(cmd) + return handleUpdateUser(c.Req.Context(), cmd) } // POST /api/users/:id/using/:orgId @@ -93,20 +94,20 @@ func UpdateUserActiveOrg(c *models.ReqContext) response.Response { userID := c.ParamsInt64(":id") orgID := c.ParamsInt64(":orgId") - if !validateUsingOrg(userID, orgID) { + if !validateUsingOrg(c.Req.Context(), userID, orgID) { return response.Error(401, "Not a valid organization", nil) } cmd := models.SetUsingOrgCommand{UserId: userID, OrgId: orgID} - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } return response.Success("Active organization changed") } -func handleUpdateUser(cmd models.UpdateUserCommand) response.Response { +func handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) response.Response { if len(cmd.Login) == 0 { cmd.Login = cmd.Email if len(cmd.Login) == 0 { @@ -114,7 +115,7 @@ func handleUpdateUser(cmd models.UpdateUserCommand) response.Response { } } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(ctx, &cmd); err != nil { return response.Error(500, "Failed to update user", err) } @@ -123,23 +124,23 @@ func handleUpdateUser(cmd models.UpdateUserCommand) response.Response { // GET /api/user/orgs func GetSignedInUserOrgList(c *models.ReqContext) response.Response { - return getUserOrgList(c.UserId) + return getUserOrgList(c.Req.Context(), c.UserId) } // GET /api/user/teams func GetSignedInUserTeamList(c *models.ReqContext) response.Response { - return getUserTeamList(c.OrgId, c.UserId) + return getUserTeamList(c.Req.Context(), c.OrgId, c.UserId) } // GET /api/users/:id/teams func GetUserTeams(c *models.ReqContext) response.Response { - return getUserTeamList(c.OrgId, c.ParamsInt64(":id")) + return getUserTeamList(c.Req.Context(), c.OrgId, c.ParamsInt64(":id")) } -func getUserTeamList(orgID int64, userID int64) response.Response { +func getUserTeamList(ctx context.Context, orgID int64, userID int64) response.Response { query := models.GetTeamsByUserQuery{OrgId: orgID, UserId: userID} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(ctx, &query); err != nil { return response.Error(500, "Failed to get user teams", err) } @@ -151,23 +152,23 @@ func getUserTeamList(orgID int64, userID int64) response.Response { // GET /api/users/:id/orgs func GetUserOrgList(c *models.ReqContext) response.Response { - return getUserOrgList(c.ParamsInt64(":id")) + return getUserOrgList(c.Req.Context(), c.ParamsInt64(":id")) } -func getUserOrgList(userID int64) response.Response { +func getUserOrgList(ctx context.Context, userID int64) response.Response { query := models.GetUserOrgListQuery{UserId: userID} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(ctx, &query); err != nil { return response.Error(500, "Failed to get user organizations", err) } return response.JSON(200, query.Result) } -func validateUsingOrg(userID int64, orgID int64) bool { +func validateUsingOrg(ctx context.Context, userID int64, orgID int64) bool { query := models.GetUserOrgListQuery{UserId: userID} - if err := bus.Dispatch(&query); err != nil { + if err := bus.DispatchCtx(ctx, &query); err != nil { return false } @@ -186,13 +187,13 @@ func validateUsingOrg(userID int64, orgID int64) bool { func UserSetUsingOrg(c *models.ReqContext) response.Response { orgID := c.ParamsInt64(":id") - if !validateUsingOrg(c.UserId, orgID) { + if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) { return response.Error(401, "Not a valid organization", nil) } cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } @@ -203,13 +204,13 @@ func UserSetUsingOrg(c *models.ReqContext) response.Response { func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *models.ReqContext) { orgID := c.ParamsInt64(":id") - if !validateUsingOrg(c.UserId, orgID) { + if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) { hs.NotFoundHandler(c) } cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { hs.NotFoundHandler(c) } @@ -246,7 +247,7 @@ func ChangeUserPassword(c *models.ReqContext, cmd models.ChangeUserPasswordComma return response.Error(500, "Failed to encode password", err) } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change user password", err) } @@ -269,7 +270,7 @@ func SetHelpFlag(c *models.ReqContext) response.Response { HelpFlags1: *bitmask, } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } @@ -282,7 +283,7 @@ func ClearHelpFlags(c *models.ReqContext) response.Response { HelpFlags1: models.HelpFlags1(0), } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } diff --git a/pkg/infra/serverlock/serverlock.go b/pkg/infra/serverlock/serverlock.go index 2eea513c851..87b19f75cd9 100644 --- a/pkg/infra/serverlock/serverlock.go +++ b/pkg/infra/serverlock/serverlock.go @@ -25,7 +25,7 @@ type ServerLockService struct { // LockAndExecute try to create a lock for this server and only executes the // `fn` function when successful. This should not be used at low internal. But services // that needs to be run once every ex 10m. -func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName string, maxInterval time.Duration, fn func()) error { +func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName string, maxInterval time.Duration, fn func(ctx context.Context)) error { // gets or creates a lockable row rowLock, err := sl.getOrCreate(ctx, actionName) if err != nil { @@ -47,7 +47,7 @@ func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName stri } if acquiredLock { - fn() + fn(ctx) } return nil diff --git a/pkg/infra/serverlock/serverlock_integration_test.go b/pkg/infra/serverlock/serverlock_integration_test.go index 33d34308db7..e3f21187988 100644 --- a/pkg/infra/serverlock/serverlock_integration_test.go +++ b/pkg/infra/serverlock/serverlock_integration_test.go @@ -15,7 +15,7 @@ func TestServerLok(t *testing.T) { sl := createTestableServerLock(t) counter := 0 - fn := func() { counter++ } + fn := func(context.Context) { counter++ } atInterval := time.Second * 1 ctx := context.Background() diff --git a/pkg/services/auth/token_cleanup.go b/pkg/services/auth/token_cleanup.go index ec6db34b30f..f211d3828b6 100644 --- a/pkg/services/auth/token_cleanup.go +++ b/pkg/services/auth/token_cleanup.go @@ -12,7 +12,7 @@ func (s *UserAuthTokenService) Run(ctx context.Context) error { maxInactiveLifetime := s.Cfg.LoginMaxInactiveLifetime maxLifetime := s.Cfg.LoginMaxLifetime - err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { + err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) { if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil { s.log.Error("An error occurred while deleting expired tokens", "err", err) } @@ -24,7 +24,7 @@ func (s *UserAuthTokenService) Run(ctx context.Context) error { for { select { case <-ticker.C: - err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { + err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) { if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil { s.log.Error("An error occurred while deleting expired tokens", "err", err) } diff --git a/pkg/services/cleanup/cleanup.go b/pkg/services/cleanup/cleanup.go index ade5c888590..c92b9053f81 100644 --- a/pkg/services/cleanup/cleanup.go +++ b/pkg/services/cleanup/cleanup.go @@ -53,7 +53,7 @@ func (srv *CleanUpService) Run(ctx context.Context) error { srv.expireOldUserInvites() srv.deleteStaleShortURLs() err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts", - time.Minute*10, func() { + time.Minute*10, func(context.Context) { srv.deleteOldLoginAttempts() }) if err != nil { diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index f8253df52cf..ee04ae1c48c 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -132,7 +132,7 @@ func (h *ContextHandler) Middleware(mContext *macaron.Context) { // update last seen every 5min if reqContext.ShouldUpdateLastSeenAt() { reqContext.Logger.Debug("Updating last user_seen_at", "user_id", reqContext.UserId) - if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil { + if err := bus.DispatchCtx(mContext.Req.Context(), &models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil { reqContext.Logger.Error("Failed to update last_seen_at", "error", err) } } diff --git a/pkg/services/sqlstore/dashboard_test.go b/pkg/services/sqlstore/dashboard_test.go index cecb1464f04..d729e4d36f6 100644 --- a/pkg/services/sqlstore/dashboard_test.go +++ b/pkg/services/sqlstore/dashboard_test.go @@ -595,7 +595,7 @@ func createUser(t *testing.T, sqlStore *SQLStore, name string, role string, isAd require.NoError(t, err) q1 := models.GetUserOrgListQuery{UserId: currentUser.Id} - err = GetUserOrgList(&q1) + err = GetUserOrgList(context.Background(), &q1) require.NoError(t, err) require.Equal(t, models.RoleType(role), q1.Result[0].Role) diff --git a/pkg/services/sqlstore/org_test.go b/pkg/services/sqlstore/org_test.go index 0fceb294259..c3ae7ed5212 100644 --- a/pkg/services/sqlstore/org_test.go +++ b/pkg/services/sqlstore/org_test.go @@ -87,9 +87,9 @@ func TestAccountDataAccess(t *testing.T) { q1 := models.GetUserOrgListQuery{UserId: ac1.Id} q2 := models.GetUserOrgListQuery{UserId: ac2.Id} - err = GetUserOrgList(&q1) + err = GetUserOrgList(context.Background(), &q1) require.NoError(t, err) - err = GetUserOrgList(&q2) + err = GetUserOrgList(context.Background(), &q2) require.NoError(t, err) require.Equal(t, q1.Result[0].OrgId, q2.Result[0].OrgId) @@ -149,7 +149,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Should be able to read user info projection", func(t *testing.T) { query := models.GetUserProfileQuery{UserId: ac1.Id} - err = GetUserProfile(&query) + err = sqlStore.GetUserProfile(context.Background(), &query) require.NoError(t, err) require.Equal(t, query.Result.Email, "ac1@test.com") @@ -158,7 +158,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Can search users", func(t *testing.T) { query := models.SearchUsersQuery{Query: ""} - err := SearchUsers(&query) + err := SearchUsers(context.Background(), &query) require.NoError(t, err) require.Equal(t, query.Result.Users[0].Email, "ac1@test.com") @@ -205,7 +205,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Can get user organizations", func(t *testing.T) { query := models.GetUserOrgListQuery{UserId: ac2.Id} - err := GetUserOrgList(&query) + err := GetUserOrgList(context.Background(), &query) require.NoError(t, err) require.Equal(t, len(query.Result), 2) @@ -247,7 +247,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Can set using org", func(t *testing.T) { cmd := models.SetUsingOrgCommand{UserId: ac2.Id, OrgId: ac1.OrgId} - err := SetUsingOrg(&cmd) + err := SetUsingOrg(context.Background(), &cmd) require.NoError(t, err) t.Run("SignedInUserQuery with a different org", func(t *testing.T) { diff --git a/pkg/services/sqlstore/stats_test.go b/pkg/services/sqlstore/stats_test.go index 88ec3e1b6fb..7e9b126e3cb 100644 --- a/pkg/services/sqlstore/stats_test.go +++ b/pkg/services/sqlstore/stats_test.go @@ -119,7 +119,7 @@ func populateDB(t *testing.T, sqlStore *SQLStore) { updateUserLastSeenAtCmd := &models.UpdateUserLastSeenAtCommand{ UserId: users[0].Id, } - err = UpdateUserLastSeenAt(updateUserLastSeenAtCmd) + err = UpdateUserLastSeenAt(context.Background(), updateUserLastSeenAtCmd) require.NoError(t, err) // force renewal of user stats diff --git a/pkg/services/sqlstore/user.go b/pkg/services/sqlstore/user.go index 343b5504e7c..9000d7abb11 100644 --- a/pkg/services/sqlstore/user.go +++ b/pkg/services/sqlstore/user.go @@ -19,19 +19,19 @@ func (ss *SQLStore) addUserQueryAndCommandHandlers() { ss.Bus.AddHandlerCtx(ss.GetSignedInUserWithCacheCtx) bus.AddHandlerCtx("sql", GetUserById) - bus.AddHandler("sql", UpdateUser) - bus.AddHandler("sql", ChangeUserPassword) - bus.AddHandler("sql", GetUserByLogin) - bus.AddHandler("sql", GetUserByEmail) - bus.AddHandler("sql", SetUsingOrg) - bus.AddHandler("sql", UpdateUserLastSeenAt) - bus.AddHandler("sql", GetUserProfile) - bus.AddHandler("sql", SearchUsers) - bus.AddHandler("sql", GetUserOrgList) - bus.AddHandler("sql", DisableUser) - bus.AddHandler("sql", BatchDisableUsers) - bus.AddHandler("sql", DeleteUser) - bus.AddHandler("sql", SetUserHelpFlag) + bus.AddHandlerCtx("sql", UpdateUser) + bus.AddHandlerCtx("sql", ChangeUserPassword) + bus.AddHandlerCtx("sql", ss.GetUserByLogin) + bus.AddHandlerCtx("sql", ss.GetUserByEmail) + bus.AddHandlerCtx("sql", SetUsingOrg) + bus.AddHandlerCtx("sql", UpdateUserLastSeenAt) + bus.AddHandlerCtx("sql", ss.GetUserProfile) + bus.AddHandlerCtx("sql", SearchUsers) + bus.AddHandlerCtx("sql", GetUserOrgList) + bus.AddHandlerCtx("sql", DisableUser) + bus.AddHandlerCtx("sql", BatchDisableUsers) + bus.AddHandlerCtx("sql", DeleteUser) + bus.AddHandlerCtx("sql", SetUserHelpFlag) } func getOrgIdForNewUser(sess *DBSession, cmd models.CreateUserCommand) (int64, error) { @@ -297,58 +297,62 @@ func GetUserById(ctx context.Context, query *models.GetUserByIdQuery) error { }) } -func GetUserByLogin(query *models.GetUserByLoginQuery) error { - if query.LoginOrEmail == "" { - return models.ErrUserNotFound - } +func (ss *SQLStore) GetUserByLogin(ctx context.Context, query *models.GetUserByLoginQuery) error { + return ss.WithDbSession(ctx, func(sess *DBSession) error { + if query.LoginOrEmail == "" { + return models.ErrUserNotFound + } - // Try and find the user by login first. - // It's not sufficient to assume that a LoginOrEmail with an "@" is an email. - user := &models.User{Login: query.LoginOrEmail} - has, err := x.Get(user) + // Try and find the user by login first. + // It's not sufficient to assume that a LoginOrEmail with an "@" is an email. + user := &models.User{Login: query.LoginOrEmail} + has, err := sess.Get(user) - if err != nil { - return err - } + if err != nil { + return err + } - if !has && strings.Contains(query.LoginOrEmail, "@") { - // If the user wasn't found, and it contains an "@" fallback to finding the - // user by email. - user = &models.User{Email: query.LoginOrEmail} - has, err = x.Get(user) - } + if !has && strings.Contains(query.LoginOrEmail, "@") { + // If the user wasn't found, and it contains an "@" fallback to finding the + // user by email. + user = &models.User{Email: query.LoginOrEmail} + has, err = sess.Get(user) + } - if err != nil { - return err - } else if !has { - return models.ErrUserNotFound - } + if err != nil { + return err + } else if !has { + return models.ErrUserNotFound + } - query.Result = user + query.Result = user - return nil + return nil + }) } -func GetUserByEmail(query *models.GetUserByEmailQuery) error { - if query.Email == "" { - return models.ErrUserNotFound - } +func (ss *SQLStore) GetUserByEmail(ctx context.Context, query *models.GetUserByEmailQuery) error { + return ss.WithDbSession(ctx, func(sess *DBSession) error { + if query.Email == "" { + return models.ErrUserNotFound + } - user := &models.User{Email: query.Email} - has, err := x.Get(user) + user := &models.User{Email: query.Email} + has, err := sess.Get(user) - if err != nil { - return err - } else if !has { - return models.ErrUserNotFound - } + if err != nil { + return err + } else if !has { + return models.ErrUserNotFound + } - query.Result = user + query.Result = user - return nil + return nil + }) } -func UpdateUser(cmd *models.UpdateUserCommand) error { +func UpdateUser(ctx context.Context, cmd *models.UpdateUserCommand) error { return inTransaction(func(sess *DBSession) error { user := models.User{ Name: cmd.Name, @@ -374,7 +378,7 @@ func UpdateUser(cmd *models.UpdateUserCommand) error { }) } -func ChangeUserPassword(cmd *models.ChangeUserPasswordCommand) error { +func ChangeUserPassword(ctx context.Context, cmd *models.ChangeUserPasswordCommand) error { return inTransaction(func(sess *DBSession) error { user := models.User{ Password: cmd.NewPassword, @@ -386,7 +390,7 @@ func ChangeUserPassword(cmd *models.ChangeUserPasswordCommand) error { }) } -func UpdateUserLastSeenAt(cmd *models.UpdateUserLastSeenAtCommand) error { +func UpdateUserLastSeenAt(ctx context.Context, cmd *models.UpdateUserLastSeenAtCommand) error { return inTransaction(func(sess *DBSession) error { user := models.User{ Id: cmd.UserId, @@ -398,9 +402,9 @@ func UpdateUserLastSeenAt(cmd *models.UpdateUserLastSeenAtCommand) error { }) } -func SetUsingOrg(cmd *models.SetUsingOrgCommand) error { +func SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error { getOrgsForUserCmd := &models.GetUserOrgListQuery{UserId: cmd.UserId} - if err := GetUserOrgList(getOrgsForUserCmd); err != nil { + if err := GetUserOrgList(ctx, getOrgsForUserCmd); err != nil { return err } @@ -429,30 +433,32 @@ func setUsingOrgInTransaction(sess *DBSession, userID int64, orgID int64) error return err } -func GetUserProfile(query *models.GetUserProfileQuery) error { - var user models.User - has, err := x.Id(query.UserId).Get(&user) +func (ss *SQLStore) GetUserProfile(ctx context.Context, query *models.GetUserProfileQuery) error { + return ss.WithDbSession(ctx, func(sess *DBSession) error { + var user models.User + has, err := sess.ID(query.UserId).Get(&user) + + if err != nil { + return err + } else if !has { + return models.ErrUserNotFound + } + + query.Result = models.UserProfileDTO{ + Id: user.Id, + Name: user.Name, + Email: user.Email, + Login: user.Login, + Theme: user.Theme, + IsGrafanaAdmin: user.IsAdmin, + IsDisabled: user.IsDisabled, + OrgId: user.OrgId, + UpdatedAt: user.Updated, + CreatedAt: user.Created, + } - if err != nil { return err - } else if !has { - return models.ErrUserNotFound - } - - query.Result = models.UserProfileDTO{ - Id: user.Id, - Name: user.Name, - Email: user.Email, - Login: user.Login, - Theme: user.Theme, - IsGrafanaAdmin: user.IsAdmin, - IsDisabled: user.IsDisabled, - OrgId: user.OrgId, - UpdatedAt: user.Updated, - CreatedAt: user.Created, - } - - return err + }) } type byOrgName []*models.UserOrgDTO @@ -476,7 +482,7 @@ func (o byOrgName) Less(i, j int) bool { return o[i].Name < o[j].Name } -func GetUserOrgList(query *models.GetUserOrgListQuery) error { +func GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error { query.Result = make([]*models.UserOrgDTO, 0) sess := x.Table("org_user") sess.Join("INNER", "org", "org_user.org_id=org.id") @@ -570,7 +576,7 @@ func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) er return err } -func SearchUsers(query *models.SearchUsersQuery) error { +func SearchUsers(ctx context.Context, query *models.SearchUsersQuery) error { query.Result = models.SearchUserQueryResult{ Users: make([]*models.UserSearchHitDTO, 0), } @@ -652,7 +658,7 @@ func SearchUsers(query *models.SearchUsersQuery) error { return err } -func DisableUser(cmd *models.DisableUserCommand) error { +func DisableUser(ctx context.Context, cmd *models.DisableUserCommand) error { user := models.User{} sess := x.Table("user") @@ -669,7 +675,7 @@ func DisableUser(cmd *models.DisableUserCommand) error { return err } -func BatchDisableUsers(cmd *models.BatchDisableUsersCommand) error { +func BatchDisableUsers(ctx context.Context, cmd *models.BatchDisableUsersCommand) error { return inTransaction(func(sess *DBSession) error { userIds := cmd.UserIds @@ -694,7 +700,7 @@ func BatchDisableUsers(cmd *models.BatchDisableUsersCommand) error { }) } -func DeleteUser(cmd *models.DeleteUserCommand) error { +func DeleteUser(ctx context.Context, cmd *models.DeleteUserCommand) error { return inTransaction(func(sess *DBSession) error { return deleteUserInTransaction(sess, cmd) }) @@ -757,7 +763,7 @@ func (ss *SQLStore) UpdateUserPermissions(userID int64, isAdmin bool) error { }) } -func SetUserHelpFlag(cmd *models.SetUserHelpFlagCommand) error { +func SetUserHelpFlag(ctx context.Context, cmd *models.SetUserHelpFlagCommand) error { return inTransaction(func(sess *DBSession) error { user := models.User{ Id: cmd.UserId, diff --git a/pkg/services/sqlstore/user_test.go b/pkg/services/sqlstore/user_test.go index 54723f72186..cf75547de1d 100644 --- a/pkg/services/sqlstore/user_test.go +++ b/pkg/services/sqlstore/user_test.go @@ -130,7 +130,7 @@ func TestUserDataAccess(t *testing.T) { // Return the first page of users and a total count query := models.SearchUsersQuery{Query: "", Page: 1, Limit: 3} - err := SearchUsers(&query) + err := SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 3) @@ -138,7 +138,7 @@ func TestUserDataAccess(t *testing.T) { // Return the second page of users and a total count query = models.SearchUsersQuery{Query: "", Page: 2, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 2) @@ -146,28 +146,28 @@ func TestUserDataAccess(t *testing.T) { // Return list of users matching query on user name query = models.SearchUsersQuery{Query: "use", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 3) require.EqualValues(t, query.Result.TotalCount, 5) query = models.SearchUsersQuery{Query: "ser1", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 1) require.EqualValues(t, query.Result.TotalCount, 1) query = models.SearchUsersQuery{Query: "USER1", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 1) require.EqualValues(t, query.Result.TotalCount, 1) query = models.SearchUsersQuery{Query: "idontexist", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 0) @@ -175,7 +175,7 @@ func TestUserDataAccess(t *testing.T) { // Return list of users matching query on email query = models.SearchUsersQuery{Query: "ser1@test.com", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 1) @@ -183,7 +183,7 @@ func TestUserDataAccess(t *testing.T) { // Return list of users matching query on login name query = models.SearchUsersQuery{Query: "loginuser1", Page: 1, Limit: 3} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 1) @@ -203,7 +203,7 @@ func TestUserDataAccess(t *testing.T) { isDisabled := false query := models.SearchUsersQuery{IsDisabled: &isDisabled} - err := SearchUsers(&query) + err := SearchUsers(context.Background(), &query) require.Nil(t, err) require.Len(t, query.Result.Users, 2) @@ -251,7 +251,7 @@ func TestUserDataAccess(t *testing.T) { require.Nil(t, err) // When the user is deleted - err = DeleteUser(&models.DeleteUserCommand{UserId: users[1].Id}) + err = DeleteUser(context.Background(), &models.DeleteUserCommand{UserId: users[1].Id}) require.Nil(t, err) query1 := &models.GetOrgUsersQuery{OrgId: users[0].OrgId} @@ -308,7 +308,7 @@ func TestUserDataAccess(t *testing.T) { require.Nil(t, err) require.NotNil(t, query3.Result) require.Equal(t, query3.OrgId, users[1].OrgId) - err = SetUsingOrg(&models.SetUsingOrgCommand{UserId: users[1].Id, OrgId: users[0].OrgId}) + err = SetUsingOrg(context.Background(), &models.SetUsingOrgCommand{UserId: users[1].Id, OrgId: users[0].OrgId}) require.Nil(t, err) query4 := &models.GetSignedInUserQuery{OrgId: 0, UserId: users[1].Id} err = ss.GetSignedInUserWithCacheCtx(context.Background(), query4) @@ -325,18 +325,18 @@ func TestUserDataAccess(t *testing.T) { IsDisabled: true, } - err = BatchDisableUsers(&disableCmd) + err = BatchDisableUsers(context.Background(), &disableCmd) require.Nil(t, err) isDisabled = true query5 := &models.SearchUsersQuery{IsDisabled: &isDisabled} - err = SearchUsers(query5) + err = SearchUsers(context.Background(), query5) require.Nil(t, err) require.EqualValues(t, query5.Result.TotalCount, 5) // the user is deleted - err = DeleteUser(&models.DeleteUserCommand{UserId: users[1].Id}) + err = DeleteUser(context.Background(), &models.DeleteUserCommand{UserId: users[1].Id}) require.Nil(t, err) // delete connected org users and permissions @@ -378,12 +378,12 @@ func TestUserDataAccess(t *testing.T) { IsDisabled: false, } - err := BatchDisableUsers(&disableCmd) + err := BatchDisableUsers(context.Background(), &disableCmd) require.Nil(t, err) isDisabled := false query := &models.SearchUsersQuery{IsDisabled: &isDisabled} - err = SearchUsers(query) + err = SearchUsers(context.Background(), query) require.Nil(t, err) require.EqualValues(t, query.Result.TotalCount, 5) @@ -410,11 +410,11 @@ func TestUserDataAccess(t *testing.T) { IsDisabled: true, } - err := BatchDisableUsers(&disableCmd) + err := BatchDisableUsers(context.Background(), &disableCmd) require.Nil(t, err) query := models.SearchUsersQuery{} - err = SearchUsers(&query) + err = SearchUsers(context.Background(), &query) require.Nil(t, err) require.EqualValues(t, query.Result.TotalCount, 5)