Chore: Add context to user (#39649)

* Add context to user

* Add context for enterprise

* Add context for UpdateUserLastSeenAtCommand

* Remove xorm
This commit is contained in:
idafurjes 2021-10-04 15:46:09 +02:00 committed by GitHub
parent 42d7c32759
commit f4f0d74838
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 152 additions and 145 deletions

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@ -13,18 +14,18 @@ import (
// GET /api/user (current authenticated user) // GET /api/user (current authenticated user)
func GetSignedInUser(c *models.ReqContext) response.Response { func GetSignedInUser(c *models.ReqContext) response.Response {
return getUserUserProfile(c.UserId) return getUserUserProfile(c.Req.Context(), c.UserId)
} }
// GET /api/users/:id // GET /api/users/:id
func GetUserByID(c *models.ReqContext) response.Response { func GetUserByID(c *models.ReqContext) response.Response {
return getUserUserProfile(c.ParamsInt64(":id")) return getUserUserProfile(c.Req.Context(), c.ParamsInt64(":id"))
} }
func getUserUserProfile(userID int64) response.Response { func getUserUserProfile(ctx context.Context, userID int64) response.Response {
query := models.GetUserProfileQuery{UserId: userID} query := models.GetUserProfileQuery{UserId: userID}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }
@ -33,7 +34,7 @@ func getUserUserProfile(userID int64) response.Response {
getAuthQuery := models.GetAuthInfoQuery{UserId: userID} getAuthQuery := models.GetAuthInfoQuery{UserId: userID}
query.Result.AuthLabels = []string{} query.Result.AuthLabels = []string{}
if err := bus.Dispatch(&getAuthQuery); err == nil { if err := bus.DispatchCtx(ctx, &getAuthQuery); err == nil {
authLabel := GetAuthProviderLabel(getAuthQuery.Result.AuthModule) authLabel := GetAuthProviderLabel(getAuthQuery.Result.AuthModule)
query.Result.AuthLabels = append(query.Result.AuthLabels, authLabel) query.Result.AuthLabels = append(query.Result.AuthLabels, authLabel)
query.Result.IsExternal = true query.Result.IsExternal = true
@ -47,7 +48,7 @@ func getUserUserProfile(userID int64) response.Response {
// GET /api/users/lookup // GET /api/users/lookup
func GetUserByLoginOrEmail(c *models.ReqContext) response.Response { func GetUserByLoginOrEmail(c *models.ReqContext) response.Response {
query := models.GetUserByLoginQuery{LoginOrEmail: c.Query("loginOrEmail")} query := models.GetUserByLoginQuery{LoginOrEmail: c.Query("loginOrEmail")}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &query); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }
@ -79,13 +80,13 @@ func UpdateSignedInUser(c *models.ReqContext, cmd models.UpdateUserCommand) resp
} }
} }
cmd.UserId = c.UserId cmd.UserId = c.UserId
return handleUpdateUser(cmd) return handleUpdateUser(c.Req.Context(), cmd)
} }
// POST /api/users/:id // POST /api/users/:id
func UpdateUser(c *models.ReqContext, cmd models.UpdateUserCommand) response.Response { func UpdateUser(c *models.ReqContext, cmd models.UpdateUserCommand) response.Response {
cmd.UserId = c.ParamsInt64(":id") cmd.UserId = c.ParamsInt64(":id")
return handleUpdateUser(cmd) return handleUpdateUser(c.Req.Context(), cmd)
} }
// POST /api/users/:id/using/:orgId // POST /api/users/:id/using/:orgId
@ -93,20 +94,20 @@ func UpdateUserActiveOrg(c *models.ReqContext) response.Response {
userID := c.ParamsInt64(":id") userID := c.ParamsInt64(":id")
orgID := c.ParamsInt64(":orgId") orgID := c.ParamsInt64(":orgId")
if !validateUsingOrg(userID, orgID) { if !validateUsingOrg(c.Req.Context(), userID, orgID) {
return response.Error(401, "Not a valid organization", nil) return response.Error(401, "Not a valid organization", nil)
} }
cmd := models.SetUsingOrgCommand{UserId: userID, OrgId: orgID} cmd := models.SetUsingOrgCommand{UserId: userID, OrgId: orgID}
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to change active organization", err) return response.Error(500, "Failed to change active organization", err)
} }
return response.Success("Active organization changed") return response.Success("Active organization changed")
} }
func handleUpdateUser(cmd models.UpdateUserCommand) response.Response { func handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) response.Response {
if len(cmd.Login) == 0 { if len(cmd.Login) == 0 {
cmd.Login = cmd.Email cmd.Login = cmd.Email
if len(cmd.Login) == 0 { if len(cmd.Login) == 0 {
@ -114,7 +115,7 @@ func handleUpdateUser(cmd models.UpdateUserCommand) response.Response {
} }
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(ctx, &cmd); err != nil {
return response.Error(500, "Failed to update user", err) return response.Error(500, "Failed to update user", err)
} }
@ -123,23 +124,23 @@ func handleUpdateUser(cmd models.UpdateUserCommand) response.Response {
// GET /api/user/orgs // GET /api/user/orgs
func GetSignedInUserOrgList(c *models.ReqContext) response.Response { func GetSignedInUserOrgList(c *models.ReqContext) response.Response {
return getUserOrgList(c.UserId) return getUserOrgList(c.Req.Context(), c.UserId)
} }
// GET /api/user/teams // GET /api/user/teams
func GetSignedInUserTeamList(c *models.ReqContext) response.Response { func GetSignedInUserTeamList(c *models.ReqContext) response.Response {
return getUserTeamList(c.OrgId, c.UserId) return getUserTeamList(c.Req.Context(), c.OrgId, c.UserId)
} }
// GET /api/users/:id/teams // GET /api/users/:id/teams
func GetUserTeams(c *models.ReqContext) response.Response { func GetUserTeams(c *models.ReqContext) response.Response {
return getUserTeamList(c.OrgId, c.ParamsInt64(":id")) return getUserTeamList(c.Req.Context(), c.OrgId, c.ParamsInt64(":id"))
} }
func getUserTeamList(orgID int64, userID int64) response.Response { func getUserTeamList(ctx context.Context, orgID int64, userID int64) response.Response {
query := models.GetTeamsByUserQuery{OrgId: orgID, UserId: userID} query := models.GetTeamsByUserQuery{OrgId: orgID, UserId: userID}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
return response.Error(500, "Failed to get user teams", err) return response.Error(500, "Failed to get user teams", err)
} }
@ -151,23 +152,23 @@ func getUserTeamList(orgID int64, userID int64) response.Response {
// GET /api/users/:id/orgs // GET /api/users/:id/orgs
func GetUserOrgList(c *models.ReqContext) response.Response { func GetUserOrgList(c *models.ReqContext) response.Response {
return getUserOrgList(c.ParamsInt64(":id")) return getUserOrgList(c.Req.Context(), c.ParamsInt64(":id"))
} }
func getUserOrgList(userID int64) response.Response { func getUserOrgList(ctx context.Context, userID int64) response.Response {
query := models.GetUserOrgListQuery{UserId: userID} query := models.GetUserOrgListQuery{UserId: userID}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
return response.Error(500, "Failed to get user organizations", err) return response.Error(500, "Failed to get user organizations", err)
} }
return response.JSON(200, query.Result) return response.JSON(200, query.Result)
} }
func validateUsingOrg(userID int64, orgID int64) bool { func validateUsingOrg(ctx context.Context, userID int64, orgID int64) bool {
query := models.GetUserOrgListQuery{UserId: userID} query := models.GetUserOrgListQuery{UserId: userID}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
return false return false
} }
@ -186,13 +187,13 @@ func validateUsingOrg(userID int64, orgID int64) bool {
func UserSetUsingOrg(c *models.ReqContext) response.Response { func UserSetUsingOrg(c *models.ReqContext) response.Response {
orgID := c.ParamsInt64(":id") orgID := c.ParamsInt64(":id")
if !validateUsingOrg(c.UserId, orgID) { if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) {
return response.Error(401, "Not a valid organization", nil) return response.Error(401, "Not a valid organization", nil)
} }
cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID}
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to change active organization", err) return response.Error(500, "Failed to change active organization", err)
} }
@ -203,13 +204,13 @@ func UserSetUsingOrg(c *models.ReqContext) response.Response {
func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *models.ReqContext) { func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *models.ReqContext) {
orgID := c.ParamsInt64(":id") orgID := c.ParamsInt64(":id")
if !validateUsingOrg(c.UserId, orgID) { if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) {
hs.NotFoundHandler(c) hs.NotFoundHandler(c)
} }
cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID}
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
hs.NotFoundHandler(c) hs.NotFoundHandler(c)
} }
@ -246,7 +247,7 @@ func ChangeUserPassword(c *models.ReqContext, cmd models.ChangeUserPasswordComma
return response.Error(500, "Failed to encode password", err) return response.Error(500, "Failed to encode password", err)
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to change user password", err) return response.Error(500, "Failed to change user password", err)
} }
@ -269,7 +270,7 @@ func SetHelpFlag(c *models.ReqContext) response.Response {
HelpFlags1: *bitmask, HelpFlags1: *bitmask,
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to update help flag", err) return response.Error(500, "Failed to update help flag", err)
} }
@ -282,7 +283,7 @@ func ClearHelpFlags(c *models.ReqContext) response.Response {
HelpFlags1: models.HelpFlags1(0), HelpFlags1: models.HelpFlags1(0),
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &cmd); err != nil {
return response.Error(500, "Failed to update help flag", err) return response.Error(500, "Failed to update help flag", err)
} }

View File

@ -25,7 +25,7 @@ type ServerLockService struct {
// LockAndExecute try to create a lock for this server and only executes the // LockAndExecute try to create a lock for this server and only executes the
// `fn` function when successful. This should not be used at low internal. But services // `fn` function when successful. This should not be used at low internal. But services
// that needs to be run once every ex 10m. // that needs to be run once every ex 10m.
func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName string, maxInterval time.Duration, fn func()) error { func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName string, maxInterval time.Duration, fn func(ctx context.Context)) error {
// gets or creates a lockable row // gets or creates a lockable row
rowLock, err := sl.getOrCreate(ctx, actionName) rowLock, err := sl.getOrCreate(ctx, actionName)
if err != nil { if err != nil {
@ -47,7 +47,7 @@ func (sl *ServerLockService) LockAndExecute(ctx context.Context, actionName stri
} }
if acquiredLock { if acquiredLock {
fn() fn(ctx)
} }
return nil return nil

View File

@ -15,7 +15,7 @@ func TestServerLok(t *testing.T) {
sl := createTestableServerLock(t) sl := createTestableServerLock(t)
counter := 0 counter := 0
fn := func() { counter++ } fn := func(context.Context) { counter++ }
atInterval := time.Second * 1 atInterval := time.Second * 1
ctx := context.Background() ctx := context.Background()

View File

@ -12,7 +12,7 @@ func (s *UserAuthTokenService) Run(ctx context.Context) error {
maxInactiveLifetime := s.Cfg.LoginMaxInactiveLifetime maxInactiveLifetime := s.Cfg.LoginMaxInactiveLifetime
maxLifetime := s.Cfg.LoginMaxLifetime maxLifetime := s.Cfg.LoginMaxLifetime
err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil { if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil {
s.log.Error("An error occurred while deleting expired tokens", "err", err) s.log.Error("An error occurred while deleting expired tokens", "err", err)
} }
@ -24,7 +24,7 @@ func (s *UserAuthTokenService) Run(ctx context.Context) error {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func() { err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil { if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil {
s.log.Error("An error occurred while deleting expired tokens", "err", err) s.log.Error("An error occurred while deleting expired tokens", "err", err)
} }

View File

@ -53,7 +53,7 @@ func (srv *CleanUpService) Run(ctx context.Context) error {
srv.expireOldUserInvites() srv.expireOldUserInvites()
srv.deleteStaleShortURLs() srv.deleteStaleShortURLs()
err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts", err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts",
time.Minute*10, func() { time.Minute*10, func(context.Context) {
srv.deleteOldLoginAttempts() srv.deleteOldLoginAttempts()
}) })
if err != nil { if err != nil {

View File

@ -132,7 +132,7 @@ func (h *ContextHandler) Middleware(mContext *macaron.Context) {
// update last seen every 5min // update last seen every 5min
if reqContext.ShouldUpdateLastSeenAt() { if reqContext.ShouldUpdateLastSeenAt() {
reqContext.Logger.Debug("Updating last user_seen_at", "user_id", reqContext.UserId) reqContext.Logger.Debug("Updating last user_seen_at", "user_id", reqContext.UserId)
if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil { if err := bus.DispatchCtx(mContext.Req.Context(), &models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil {
reqContext.Logger.Error("Failed to update last_seen_at", "error", err) reqContext.Logger.Error("Failed to update last_seen_at", "error", err)
} }
} }

View File

@ -595,7 +595,7 @@ func createUser(t *testing.T, sqlStore *SQLStore, name string, role string, isAd
require.NoError(t, err) require.NoError(t, err)
q1 := models.GetUserOrgListQuery{UserId: currentUser.Id} q1 := models.GetUserOrgListQuery{UserId: currentUser.Id}
err = GetUserOrgList(&q1) err = GetUserOrgList(context.Background(), &q1)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, models.RoleType(role), q1.Result[0].Role) require.Equal(t, models.RoleType(role), q1.Result[0].Role)

View File

@ -87,9 +87,9 @@ func TestAccountDataAccess(t *testing.T) {
q1 := models.GetUserOrgListQuery{UserId: ac1.Id} q1 := models.GetUserOrgListQuery{UserId: ac1.Id}
q2 := models.GetUserOrgListQuery{UserId: ac2.Id} q2 := models.GetUserOrgListQuery{UserId: ac2.Id}
err = GetUserOrgList(&q1) err = GetUserOrgList(context.Background(), &q1)
require.NoError(t, err) require.NoError(t, err)
err = GetUserOrgList(&q2) err = GetUserOrgList(context.Background(), &q2)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, q1.Result[0].OrgId, q2.Result[0].OrgId) require.Equal(t, q1.Result[0].OrgId, q2.Result[0].OrgId)
@ -149,7 +149,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Should be able to read user info projection", func(t *testing.T) { t.Run("Should be able to read user info projection", func(t *testing.T) {
query := models.GetUserProfileQuery{UserId: ac1.Id} query := models.GetUserProfileQuery{UserId: ac1.Id}
err = GetUserProfile(&query) err = sqlStore.GetUserProfile(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, query.Result.Email, "ac1@test.com") require.Equal(t, query.Result.Email, "ac1@test.com")
@ -158,7 +158,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Can search users", func(t *testing.T) { t.Run("Can search users", func(t *testing.T) {
query := models.SearchUsersQuery{Query: ""} query := models.SearchUsersQuery{Query: ""}
err := SearchUsers(&query) err := SearchUsers(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, query.Result.Users[0].Email, "ac1@test.com") require.Equal(t, query.Result.Users[0].Email, "ac1@test.com")
@ -205,7 +205,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Can get user organizations", func(t *testing.T) { t.Run("Can get user organizations", func(t *testing.T) {
query := models.GetUserOrgListQuery{UserId: ac2.Id} query := models.GetUserOrgListQuery{UserId: ac2.Id}
err := GetUserOrgList(&query) err := GetUserOrgList(context.Background(), &query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(query.Result), 2) require.Equal(t, len(query.Result), 2)
@ -247,7 +247,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Can set using org", func(t *testing.T) { t.Run("Can set using org", func(t *testing.T) {
cmd := models.SetUsingOrgCommand{UserId: ac2.Id, OrgId: ac1.OrgId} cmd := models.SetUsingOrgCommand{UserId: ac2.Id, OrgId: ac1.OrgId}
err := SetUsingOrg(&cmd) err := SetUsingOrg(context.Background(), &cmd)
require.NoError(t, err) require.NoError(t, err)
t.Run("SignedInUserQuery with a different org", func(t *testing.T) { t.Run("SignedInUserQuery with a different org", func(t *testing.T) {

View File

@ -119,7 +119,7 @@ func populateDB(t *testing.T, sqlStore *SQLStore) {
updateUserLastSeenAtCmd := &models.UpdateUserLastSeenAtCommand{ updateUserLastSeenAtCmd := &models.UpdateUserLastSeenAtCommand{
UserId: users[0].Id, UserId: users[0].Id,
} }
err = UpdateUserLastSeenAt(updateUserLastSeenAtCmd) err = UpdateUserLastSeenAt(context.Background(), updateUserLastSeenAtCmd)
require.NoError(t, err) require.NoError(t, err)
// force renewal of user stats // force renewal of user stats

View File

@ -19,19 +19,19 @@ func (ss *SQLStore) addUserQueryAndCommandHandlers() {
ss.Bus.AddHandlerCtx(ss.GetSignedInUserWithCacheCtx) ss.Bus.AddHandlerCtx(ss.GetSignedInUserWithCacheCtx)
bus.AddHandlerCtx("sql", GetUserById) bus.AddHandlerCtx("sql", GetUserById)
bus.AddHandler("sql", UpdateUser) bus.AddHandlerCtx("sql", UpdateUser)
bus.AddHandler("sql", ChangeUserPassword) bus.AddHandlerCtx("sql", ChangeUserPassword)
bus.AddHandler("sql", GetUserByLogin) bus.AddHandlerCtx("sql", ss.GetUserByLogin)
bus.AddHandler("sql", GetUserByEmail) bus.AddHandlerCtx("sql", ss.GetUserByEmail)
bus.AddHandler("sql", SetUsingOrg) bus.AddHandlerCtx("sql", SetUsingOrg)
bus.AddHandler("sql", UpdateUserLastSeenAt) bus.AddHandlerCtx("sql", UpdateUserLastSeenAt)
bus.AddHandler("sql", GetUserProfile) bus.AddHandlerCtx("sql", ss.GetUserProfile)
bus.AddHandler("sql", SearchUsers) bus.AddHandlerCtx("sql", SearchUsers)
bus.AddHandler("sql", GetUserOrgList) bus.AddHandlerCtx("sql", GetUserOrgList)
bus.AddHandler("sql", DisableUser) bus.AddHandlerCtx("sql", DisableUser)
bus.AddHandler("sql", BatchDisableUsers) bus.AddHandlerCtx("sql", BatchDisableUsers)
bus.AddHandler("sql", DeleteUser) bus.AddHandlerCtx("sql", DeleteUser)
bus.AddHandler("sql", SetUserHelpFlag) bus.AddHandlerCtx("sql", SetUserHelpFlag)
} }
func getOrgIdForNewUser(sess *DBSession, cmd models.CreateUserCommand) (int64, error) { func getOrgIdForNewUser(sess *DBSession, cmd models.CreateUserCommand) (int64, error) {
@ -297,7 +297,8 @@ func GetUserById(ctx context.Context, query *models.GetUserByIdQuery) error {
}) })
} }
func GetUserByLogin(query *models.GetUserByLoginQuery) error { func (ss *SQLStore) GetUserByLogin(ctx context.Context, query *models.GetUserByLoginQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
if query.LoginOrEmail == "" { if query.LoginOrEmail == "" {
return models.ErrUserNotFound return models.ErrUserNotFound
} }
@ -305,7 +306,7 @@ func GetUserByLogin(query *models.GetUserByLoginQuery) error {
// Try and find the user by login first. // Try and find the user by login first.
// It's not sufficient to assume that a LoginOrEmail with an "@" is an email. // It's not sufficient to assume that a LoginOrEmail with an "@" is an email.
user := &models.User{Login: query.LoginOrEmail} user := &models.User{Login: query.LoginOrEmail}
has, err := x.Get(user) has, err := sess.Get(user)
if err != nil { if err != nil {
return err return err
@ -315,7 +316,7 @@ func GetUserByLogin(query *models.GetUserByLoginQuery) error {
// If the user wasn't found, and it contains an "@" fallback to finding the // If the user wasn't found, and it contains an "@" fallback to finding the
// user by email. // user by email.
user = &models.User{Email: query.LoginOrEmail} user = &models.User{Email: query.LoginOrEmail}
has, err = x.Get(user) has, err = sess.Get(user)
} }
if err != nil { if err != nil {
@ -327,15 +328,17 @@ func GetUserByLogin(query *models.GetUserByLoginQuery) error {
query.Result = user query.Result = user
return nil return nil
})
} }
func GetUserByEmail(query *models.GetUserByEmailQuery) error { func (ss *SQLStore) GetUserByEmail(ctx context.Context, query *models.GetUserByEmailQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
if query.Email == "" { if query.Email == "" {
return models.ErrUserNotFound return models.ErrUserNotFound
} }
user := &models.User{Email: query.Email} user := &models.User{Email: query.Email}
has, err := x.Get(user) has, err := sess.Get(user)
if err != nil { if err != nil {
return err return err
@ -346,9 +349,10 @@ func GetUserByEmail(query *models.GetUserByEmailQuery) error {
query.Result = user query.Result = user
return nil return nil
})
} }
func UpdateUser(cmd *models.UpdateUserCommand) error { func UpdateUser(ctx context.Context, cmd *models.UpdateUserCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
user := models.User{ user := models.User{
Name: cmd.Name, Name: cmd.Name,
@ -374,7 +378,7 @@ func UpdateUser(cmd *models.UpdateUserCommand) error {
}) })
} }
func ChangeUserPassword(cmd *models.ChangeUserPasswordCommand) error { func ChangeUserPassword(ctx context.Context, cmd *models.ChangeUserPasswordCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
user := models.User{ user := models.User{
Password: cmd.NewPassword, Password: cmd.NewPassword,
@ -386,7 +390,7 @@ func ChangeUserPassword(cmd *models.ChangeUserPasswordCommand) error {
}) })
} }
func UpdateUserLastSeenAt(cmd *models.UpdateUserLastSeenAtCommand) error { func UpdateUserLastSeenAt(ctx context.Context, cmd *models.UpdateUserLastSeenAtCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
user := models.User{ user := models.User{
Id: cmd.UserId, Id: cmd.UserId,
@ -398,9 +402,9 @@ func UpdateUserLastSeenAt(cmd *models.UpdateUserLastSeenAtCommand) error {
}) })
} }
func SetUsingOrg(cmd *models.SetUsingOrgCommand) error { func SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error {
getOrgsForUserCmd := &models.GetUserOrgListQuery{UserId: cmd.UserId} getOrgsForUserCmd := &models.GetUserOrgListQuery{UserId: cmd.UserId}
if err := GetUserOrgList(getOrgsForUserCmd); err != nil { if err := GetUserOrgList(ctx, getOrgsForUserCmd); err != nil {
return err return err
} }
@ -429,9 +433,10 @@ func setUsingOrgInTransaction(sess *DBSession, userID int64, orgID int64) error
return err return err
} }
func GetUserProfile(query *models.GetUserProfileQuery) error { func (ss *SQLStore) GetUserProfile(ctx context.Context, query *models.GetUserProfileQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
var user models.User var user models.User
has, err := x.Id(query.UserId).Get(&user) has, err := sess.ID(query.UserId).Get(&user)
if err != nil { if err != nil {
return err return err
@ -453,6 +458,7 @@ func GetUserProfile(query *models.GetUserProfileQuery) error {
} }
return err return err
})
} }
type byOrgName []*models.UserOrgDTO type byOrgName []*models.UserOrgDTO
@ -476,7 +482,7 @@ func (o byOrgName) Less(i, j int) bool {
return o[i].Name < o[j].Name return o[i].Name < o[j].Name
} }
func GetUserOrgList(query *models.GetUserOrgListQuery) error { func GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error {
query.Result = make([]*models.UserOrgDTO, 0) query.Result = make([]*models.UserOrgDTO, 0)
sess := x.Table("org_user") sess := x.Table("org_user")
sess.Join("INNER", "org", "org_user.org_id=org.id") sess.Join("INNER", "org", "org_user.org_id=org.id")
@ -570,7 +576,7 @@ func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) er
return err return err
} }
func SearchUsers(query *models.SearchUsersQuery) error { func SearchUsers(ctx context.Context, query *models.SearchUsersQuery) error {
query.Result = models.SearchUserQueryResult{ query.Result = models.SearchUserQueryResult{
Users: make([]*models.UserSearchHitDTO, 0), Users: make([]*models.UserSearchHitDTO, 0),
} }
@ -652,7 +658,7 @@ func SearchUsers(query *models.SearchUsersQuery) error {
return err return err
} }
func DisableUser(cmd *models.DisableUserCommand) error { func DisableUser(ctx context.Context, cmd *models.DisableUserCommand) error {
user := models.User{} user := models.User{}
sess := x.Table("user") sess := x.Table("user")
@ -669,7 +675,7 @@ func DisableUser(cmd *models.DisableUserCommand) error {
return err return err
} }
func BatchDisableUsers(cmd *models.BatchDisableUsersCommand) error { func BatchDisableUsers(ctx context.Context, cmd *models.BatchDisableUsersCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
userIds := cmd.UserIds userIds := cmd.UserIds
@ -694,7 +700,7 @@ func BatchDisableUsers(cmd *models.BatchDisableUsersCommand) error {
}) })
} }
func DeleteUser(cmd *models.DeleteUserCommand) error { func DeleteUser(ctx context.Context, cmd *models.DeleteUserCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
return deleteUserInTransaction(sess, cmd) return deleteUserInTransaction(sess, cmd)
}) })
@ -757,7 +763,7 @@ func (ss *SQLStore) UpdateUserPermissions(userID int64, isAdmin bool) error {
}) })
} }
func SetUserHelpFlag(cmd *models.SetUserHelpFlagCommand) error { func SetUserHelpFlag(ctx context.Context, cmd *models.SetUserHelpFlagCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
user := models.User{ user := models.User{
Id: cmd.UserId, Id: cmd.UserId,

View File

@ -130,7 +130,7 @@ func TestUserDataAccess(t *testing.T) {
// Return the first page of users and a total count // Return the first page of users and a total count
query := models.SearchUsersQuery{Query: "", Page: 1, Limit: 3} query := models.SearchUsersQuery{Query: "", Page: 1, Limit: 3}
err := SearchUsers(&query) err := SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 3) require.Len(t, query.Result.Users, 3)
@ -138,7 +138,7 @@ func TestUserDataAccess(t *testing.T) {
// Return the second page of users and a total count // Return the second page of users and a total count
query = models.SearchUsersQuery{Query: "", Page: 2, Limit: 3} query = models.SearchUsersQuery{Query: "", Page: 2, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 2) require.Len(t, query.Result.Users, 2)
@ -146,28 +146,28 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on user name // Return list of users matching query on user name
query = models.SearchUsersQuery{Query: "use", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "use", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 3) require.Len(t, query.Result.Users, 3)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)
query = models.SearchUsersQuery{Query: "ser1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "ser1", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
require.EqualValues(t, query.Result.TotalCount, 1) require.EqualValues(t, query.Result.TotalCount, 1)
query = models.SearchUsersQuery{Query: "USER1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "USER1", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
require.EqualValues(t, query.Result.TotalCount, 1) require.EqualValues(t, query.Result.TotalCount, 1)
query = models.SearchUsersQuery{Query: "idontexist", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "idontexist", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 0) require.Len(t, query.Result.Users, 0)
@ -175,7 +175,7 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on email // Return list of users matching query on email
query = models.SearchUsersQuery{Query: "ser1@test.com", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "ser1@test.com", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
@ -183,7 +183,7 @@ func TestUserDataAccess(t *testing.T) {
// Return list of users matching query on login name // Return list of users matching query on login name
query = models.SearchUsersQuery{Query: "loginuser1", Page: 1, Limit: 3} query = models.SearchUsersQuery{Query: "loginuser1", Page: 1, Limit: 3}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 1) require.Len(t, query.Result.Users, 1)
@ -203,7 +203,7 @@ func TestUserDataAccess(t *testing.T) {
isDisabled := false isDisabled := false
query := models.SearchUsersQuery{IsDisabled: &isDisabled} query := models.SearchUsersQuery{IsDisabled: &isDisabled}
err := SearchUsers(&query) err := SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Len(t, query.Result.Users, 2) require.Len(t, query.Result.Users, 2)
@ -251,7 +251,7 @@ func TestUserDataAccess(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
// When the user is deleted // When the user is deleted
err = DeleteUser(&models.DeleteUserCommand{UserId: users[1].Id}) err = DeleteUser(context.Background(), &models.DeleteUserCommand{UserId: users[1].Id})
require.Nil(t, err) require.Nil(t, err)
query1 := &models.GetOrgUsersQuery{OrgId: users[0].OrgId} query1 := &models.GetOrgUsersQuery{OrgId: users[0].OrgId}
@ -308,7 +308,7 @@ func TestUserDataAccess(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.NotNil(t, query3.Result) require.NotNil(t, query3.Result)
require.Equal(t, query3.OrgId, users[1].OrgId) require.Equal(t, query3.OrgId, users[1].OrgId)
err = SetUsingOrg(&models.SetUsingOrgCommand{UserId: users[1].Id, OrgId: users[0].OrgId}) err = SetUsingOrg(context.Background(), &models.SetUsingOrgCommand{UserId: users[1].Id, OrgId: users[0].OrgId})
require.Nil(t, err) require.Nil(t, err)
query4 := &models.GetSignedInUserQuery{OrgId: 0, UserId: users[1].Id} query4 := &models.GetSignedInUserQuery{OrgId: 0, UserId: users[1].Id}
err = ss.GetSignedInUserWithCacheCtx(context.Background(), query4) err = ss.GetSignedInUserWithCacheCtx(context.Background(), query4)
@ -325,18 +325,18 @@ func TestUserDataAccess(t *testing.T) {
IsDisabled: true, IsDisabled: true,
} }
err = BatchDisableUsers(&disableCmd) err = BatchDisableUsers(context.Background(), &disableCmd)
require.Nil(t, err) require.Nil(t, err)
isDisabled = true isDisabled = true
query5 := &models.SearchUsersQuery{IsDisabled: &isDisabled} query5 := &models.SearchUsersQuery{IsDisabled: &isDisabled}
err = SearchUsers(query5) err = SearchUsers(context.Background(), query5)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query5.Result.TotalCount, 5) require.EqualValues(t, query5.Result.TotalCount, 5)
// the user is deleted // the user is deleted
err = DeleteUser(&models.DeleteUserCommand{UserId: users[1].Id}) err = DeleteUser(context.Background(), &models.DeleteUserCommand{UserId: users[1].Id})
require.Nil(t, err) require.Nil(t, err)
// delete connected org users and permissions // delete connected org users and permissions
@ -378,12 +378,12 @@ func TestUserDataAccess(t *testing.T) {
IsDisabled: false, IsDisabled: false,
} }
err := BatchDisableUsers(&disableCmd) err := BatchDisableUsers(context.Background(), &disableCmd)
require.Nil(t, err) require.Nil(t, err)
isDisabled := false isDisabled := false
query := &models.SearchUsersQuery{IsDisabled: &isDisabled} query := &models.SearchUsersQuery{IsDisabled: &isDisabled}
err = SearchUsers(query) err = SearchUsers(context.Background(), query)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)
@ -410,11 +410,11 @@ func TestUserDataAccess(t *testing.T) {
IsDisabled: true, IsDisabled: true,
} }
err := BatchDisableUsers(&disableCmd) err := BatchDisableUsers(context.Background(), &disableCmd)
require.Nil(t, err) require.Nil(t, err)
query := models.SearchUsersQuery{} query := models.SearchUsersQuery{}
err = SearchUsers(&query) err = SearchUsers(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.EqualValues(t, query.Result.TotalCount, 5) require.EqualValues(t, query.Result.TotalCount, 5)