diff --git a/pkg/api/admin_users.go b/pkg/api/admin_users.go index f9bb2150964..1f9ee5dfc36 100644 --- a/pkg/api/admin_users.go +++ b/pkg/api/admin_users.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "fmt" @@ -123,7 +124,7 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response { // External users shouldn't be disabled from API authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} - if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { + if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { return response.Error(500, "Could not disable external user", nil) } @@ -149,7 +150,7 @@ func AdminEnableUser(c *models.ReqContext) response.Response { // External users shouldn't be disabled from API authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} - if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { + if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { return response.Error(500, "Could not enable external user", nil) } diff --git a/pkg/api/ldap_debug.go b/pkg/api/ldap_debug.go index 4a0dc7537bd..bba65c70257 100644 --- a/pkg/api/ldap_debug.go +++ b/pkg/api/ldap_debug.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "fmt" "net/http" @@ -176,7 +177,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.Id, AuthModule: models.AuthModuleLDAP} - if err := bus.Dispatch(authModuleQuery); err != nil { // validate the userId comes from LDAP + if err := bus.DispatchCtx(context.TODO(), authModuleQuery); err != nil { // validate the userId comes from LDAP if errors.Is(err, models.ErrUserNotFound) { return response.Error(404, models.ErrUserNotFound.Error(), nil) } diff --git a/pkg/api/login.go b/pkg/api/login.go index 2089838e1c6..c9ce5e7a84e 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -208,7 +208,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response { Cfg: hs.Cfg, } - err := bus.Dispatch(authQuery) + err := bus.DispatchCtx(c.Req.Context(), authQuery) authModule = authQuery.AuthModule if err != nil { resp = response.Error(401, "Invalid username or password", err) diff --git a/pkg/api/org_invite.go b/pkg/api/org_invite.go index 51b567d6464..f44a835bf4c 100644 --- a/pkg/api/org_invite.go +++ b/pkg/api/org_invite.go @@ -37,7 +37,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R // first try get existing user userQuery := models.GetUserByLoginQuery{LoginOrEmail: inviteDto.LoginOrEmail} - if err := bus.Dispatch(&userQuery); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil { if !errors.Is(err, models.ErrUserNotFound) { return response.Error(500, "Failed to query db for existing user check", err) } diff --git a/pkg/api/password.go b/pkg/api/password.go index 4aa338a1fd7..6d529f6df88 100644 --- a/pkg/api/password.go +++ b/pkg/api/password.go @@ -21,7 +21,7 @@ func SendResetPasswordEmail(c *models.ReqContext, form dtos.SendResetPasswordEma userQuery := models.GetUserByLoginQuery{LoginOrEmail: form.UserOrEmail} - if err := bus.Dispatch(&userQuery); err != nil { + if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil { c.Logger.Info("Requested password reset for user that was not found", "user", userQuery.LoginOrEmail) return response.Error(200, "Email sent", err) } diff --git a/pkg/api/signup.go b/pkg/api/signup.go index 6b9331874d0..bbeab002ccb 100644 --- a/pkg/api/signup.go +++ b/pkg/api/signup.go @@ -29,7 +29,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response { } existing := models.GetUserByLoginQuery{LoginOrEmail: form.Email} - if err := bus.Dispatch(&existing); err == nil { + if err := bus.DispatchCtx(c.Req.Context(), &existing); err == nil { return response.Error(422, "User with same email address already exists", nil) } diff --git a/pkg/login/auth.go b/pkg/login/auth.go index d6fbe157bfa..5f33d24f7a6 100644 --- a/pkg/login/auth.go +++ b/pkg/login/auth.go @@ -1,6 +1,7 @@ package login import ( + "context" "errors" "github.com/grafana/grafana/pkg/bus" @@ -25,12 +26,12 @@ var ( var loginLogger = log.New("login") func Init() { - bus.AddHandler("auth", authenticateUser) + bus.AddHandlerCtx("auth", authenticateUser) } // authenticateUser authenticates the user via username & password -func authenticateUser(query *models.LoginUserQuery) error { - if err := validateLoginAttempts(query); err != nil { +func authenticateUser(ctx context.Context, query *models.LoginUserQuery) error { + if err := validateLoginAttempts(ctx, query); err != nil { return err } @@ -38,14 +39,14 @@ func authenticateUser(query *models.LoginUserQuery) error { return err } - err := loginUsingGrafanaDB(query) + err := loginUsingGrafanaDB(ctx, query) if err == nil || (!errors.Is(err, models.ErrUserNotFound) && !errors.Is(err, ErrInvalidCredentials) && !errors.Is(err, ErrUserDisabled)) { query.AuthModule = "grafana" return err } - ldapEnabled, ldapErr := loginUsingLDAP(query) + ldapEnabled, ldapErr := loginUsingLDAP(ctx, query) if ldapEnabled { query.AuthModule = models.AuthModuleLDAP if ldapErr == nil || !errors.Is(ldapErr, ldap.ErrInvalidCredentials) { @@ -58,7 +59,7 @@ func authenticateUser(query *models.LoginUserQuery) error { } if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) { - if err := saveInvalidLoginAttempt(query); err != nil { + if err := saveInvalidLoginAttempt(ctx, query); err != nil { loginLogger.Error("Failed to save invalid login attempt", "err", err) } diff --git a/pkg/login/auth_test.go b/pkg/login/auth_test.go index e2175daffdf..ce3567997fb 100644 --- a/pkg/login/auth_test.go +++ b/pkg/login/auth_test.go @@ -1,6 +1,7 @@ package login import ( + "context" "errors" "testing" @@ -20,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) { Username: "user", Password: "", } - err := authenticateUser(&loginQuery) + err := authenticateUser(context.Background(), &loginQuery) require.EqualError(t, err, ErrPasswordEmpty.Error()) assert.False(t, sc.grafanaLoginWasCalled) @@ -34,7 +35,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrTooManyLoginAttempts.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -50,7 +51,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -67,7 +68,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, customErr.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -83,7 +84,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(false, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, models.ErrUserNotFound.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -99,7 +100,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrInvalidCredentials.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -115,7 +116,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -132,7 +133,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, customErr, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, customErr.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -148,7 +149,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(sc.loginUserQuery) + err := authenticateUser(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrInvalidCredentials.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -169,28 +170,28 @@ type authScenarioContext struct { type authScenarioFunc func(sc *authScenarioContext) func mockLoginUsingGrafanaDB(err error, sc *authScenarioContext) { - loginUsingGrafanaDB = func(query *models.LoginUserQuery) error { + loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error { sc.grafanaLoginWasCalled = true return err } } func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) { - loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { + loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) { sc.ldapLoginWasCalled = true return enabled, err } } func mockLoginAttemptValidation(err error, sc *authScenarioContext) { - validateLoginAttempts = func(*models.LoginUserQuery) error { + validateLoginAttempts = func(context.Context, *models.LoginUserQuery) error { sc.loginAttemptValidationWasCalled = true return err } } func mockSaveInvalidLoginAttempt(sc *authScenarioContext) { - saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { + saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error { sc.saveInvalidLoginAttemptWasCalled = true return nil } diff --git a/pkg/login/brute_force_login_protection.go b/pkg/login/brute_force_login_protection.go index 3d914b48e45..2026ff6a076 100644 --- a/pkg/login/brute_force_login_protection.go +++ b/pkg/login/brute_force_login_protection.go @@ -1,6 +1,7 @@ package login import ( + "context" "time" "github.com/grafana/grafana/pkg/bus" @@ -12,7 +13,7 @@ var ( loginAttemptsWindow = time.Minute * 5 ) -var validateLoginAttempts = func(query *models.LoginUserQuery) error { +var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery) error { if query.Cfg.DisableBruteForceLoginProtection { return nil } @@ -22,7 +23,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error { Since: time.Now().Add(-loginAttemptsWindow), } - if err := bus.Dispatch(&loginAttemptCountQuery); err != nil { + if err := bus.DispatchCtx(ctx, &loginAttemptCountQuery); err != nil { return err } @@ -33,7 +34,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error { return nil } -var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { +var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error { if query.Cfg.DisableBruteForceLoginProtection { return nil } @@ -43,5 +44,5 @@ var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { IpAddress: query.IpAddress, } - return bus.Dispatch(&loginAttemptCommand) + return bus.DispatchCtx(ctx, &loginAttemptCommand) } diff --git a/pkg/login/brute_force_login_protection_test.go b/pkg/login/brute_force_login_protection_test.go index f8d9cfb2fdf..cc01089f9a5 100644 --- a/pkg/login/brute_force_login_protection_test.go +++ b/pkg/login/brute_force_login_protection_test.go @@ -1,6 +1,7 @@ package login import ( + "context" "testing" "github.com/grafana/grafana/pkg/bus" @@ -61,7 +62,7 @@ func TestValidateLoginAttempts(t *testing.T) { withLoginAttempts(t, tc.loginAttempts) query := &models.LoginUserQuery{Username: "user", Cfg: tc.cfg} - err := validateLoginAttempts(query) + err := validateLoginAttempts(context.Background(), query) require.Equal(t, tc.expected, err) }) } @@ -77,7 +78,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) { return nil }) - err := saveInvalidLoginAttempt(&models.LoginUserQuery{ + err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{ Username: "user", Password: "pwd", IpAddress: "192.168.1.1:56433", @@ -99,7 +100,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) { return nil }) - err := saveInvalidLoginAttempt(&models.LoginUserQuery{ + err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{ Username: "user", Password: "pwd", IpAddress: "192.168.1.1:56433", diff --git a/pkg/login/grafana_login.go b/pkg/login/grafana_login.go index 42f2c788c2b..ab1acdb9d27 100644 --- a/pkg/login/grafana_login.go +++ b/pkg/login/grafana_login.go @@ -1,6 +1,7 @@ package login import ( + "context" "crypto/subtle" "github.com/grafana/grafana/pkg/bus" @@ -20,10 +21,10 @@ var validatePassword = func(providedPassword string, userPassword string, userSa return nil } -var loginUsingGrafanaDB = func(query *models.LoginUserQuery) error { +var loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error { userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.Username} - if err := bus.Dispatch(&userQuery); err != nil { + if err := bus.DispatchCtx(ctx, &userQuery); err != nil { return err } diff --git a/pkg/login/grafana_login_test.go b/pkg/login/grafana_login_test.go index 528827e53a5..bba7125f2ea 100644 --- a/pkg/login/grafana_login_test.go +++ b/pkg/login/grafana_login_test.go @@ -1,6 +1,7 @@ package login import ( + "context" "testing" "github.com/grafana/grafana/pkg/bus" @@ -12,7 +13,7 @@ import ( func TestLoginUsingGrafanaDB(t *testing.T) { grafanaLoginScenario(t, "When login with non-existing user", func(sc *grafanaLoginScenarioContext) { sc.withNonExistingUser() - err := loginUsingGrafanaDB(sc.loginUserQuery) + err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery) require.EqualError(t, err, models.ErrUserNotFound.Error()) assert.False(t, sc.validatePasswordCalled) @@ -21,7 +22,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) { grafanaLoginScenario(t, "When login with invalid credentials", func(sc *grafanaLoginScenarioContext) { sc.withInvalidPassword() - err := loginUsingGrafanaDB(sc.loginUserQuery) + err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrInvalidCredentials.Error()) @@ -31,7 +32,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) { grafanaLoginScenario(t, "When login with valid credentials", func(sc *grafanaLoginScenarioContext) { sc.withValidCredentials() - err := loginUsingGrafanaDB(sc.loginUserQuery) + err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.True(t, sc.validatePasswordCalled) @@ -43,7 +44,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) { grafanaLoginScenario(t, "When login with disabled user", func(sc *grafanaLoginScenarioContext) { sc.withDisabledUser() - err := loginUsingGrafanaDB(sc.loginUserQuery) + err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrUserDisabled.Error()) assert.False(t, sc.validatePasswordCalled) diff --git a/pkg/login/ldap_login.go b/pkg/login/ldap_login.go index cb5d984e736..0cf5e6d8c7a 100644 --- a/pkg/login/ldap_login.go +++ b/pkg/login/ldap_login.go @@ -1,6 +1,7 @@ package login import ( + "context" "errors" "github.com/grafana/grafana/pkg/bus" @@ -26,7 +27,7 @@ var ldapLogger = log.New("login.ldap") // loginUsingLDAP logs in user using LDAP. It returns whether LDAP is enabled and optional error and query arg will be // populated with the logged in user if successful. -var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { +var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) { enabled := isLDAPEnabled() if !enabled { @@ -57,7 +58,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { ExternalUser: externalUser, SignupAllowed: setting.LDAPAllowSignup, } - err = bus.Dispatch(upsert) + err = bus.DispatchCtx(ctx, upsert) if err != nil { return true, err } diff --git a/pkg/login/ldap_login_test.go b/pkg/login/ldap_login_test.go index e36c5dd5b5a..25c53bada07 100644 --- a/pkg/login/ldap_login_test.go +++ b/pkg/login/ldap_login_test.go @@ -1,6 +1,7 @@ package login import ( + "context" "errors" "testing" @@ -27,7 +28,7 @@ func TestLoginUsingLDAP(t *testing.T) { return config, nil } - enabled, err := loginUsingLDAP(sc.loginUserQuery) + enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery) require.EqualError(t, err, errTest.Error()) assert.True(t, enabled) @@ -38,7 +39,7 @@ func TestLoginUsingLDAP(t *testing.T) { setting.LDAPEnabled = false sc.withLoginResult(false) - enabled, err := loginUsingLDAP(sc.loginUserQuery) + enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.False(t, enabled) diff --git a/pkg/services/cleanup/cleanup.go b/pkg/services/cleanup/cleanup.go index 48ec74a3c4a..4a8468429ab 100644 --- a/pkg/services/cleanup/cleanup.go +++ b/pkg/services/cleanup/cleanup.go @@ -50,11 +50,11 @@ func (srv *CleanUpService) Run(ctx context.Context) error { srv.deleteExpiredSnapshots() srv.deleteExpiredDashboardVersions() srv.cleanUpOldAnnotations(ctxWithTimeout) - srv.expireOldUserInvites() - srv.deleteStaleShortURLs() + srv.expireOldUserInvites(ctx) + srv.deleteStaleShortURLs(ctx) err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts", time.Minute*10, func(context.Context) { - srv.deleteOldLoginAttempts() + srv.deleteOldLoginAttempts(ctx) }) if err != nil { srv.log.Error("failed to lock and execute cleanup of old login attempts", "error", err) @@ -143,7 +143,7 @@ func (srv *CleanUpService) deleteExpiredDashboardVersions() { } } -func (srv *CleanUpService) deleteOldLoginAttempts() { +func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) { if srv.Cfg.DisableBruteForceLoginProtection { return } @@ -151,31 +151,31 @@ func (srv *CleanUpService) deleteOldLoginAttempts() { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: time.Now().Add(time.Minute * -10), } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(ctx, &cmd); err != nil { srv.log.Error("Problem deleting expired login attempts", "error", err.Error()) } else { srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows) } } -func (srv *CleanUpService) expireOldUserInvites() { +func (srv *CleanUpService) expireOldUserInvites(ctx context.Context) { maxInviteLifetime := srv.Cfg.UserInviteMaxLifetime cmd := models.ExpireTempUsersCommand{ OlderThan: time.Now().Add(-maxInviteLifetime), } - if err := bus.Dispatch(&cmd); err != nil { + if err := bus.DispatchCtx(ctx, &cmd); err != nil { srv.log.Error("Problem expiring user invites", "error", err.Error()) } else { srv.log.Debug("Expired user invites", "rows affected", cmd.NumExpired) } } -func (srv *CleanUpService) deleteStaleShortURLs() { +func (srv *CleanUpService) deleteStaleShortURLs(ctx context.Context) { cmd := models.DeleteShortUrlCommand{ OlderThan: time.Now().Add(-time.Hour * 24 * 7), } - if err := srv.ShortURLService.DeleteStaleShortURLs(context.Background(), &cmd); err != nil { + if err := srv.ShortURLService.DeleteStaleShortURLs(ctx, &cmd); err != nil { srv.log.Error("Problem deleting stale short urls", "error", err.Error()) } else { srv.log.Debug("Deleted short urls", "rows affected", cmd.NumDeleted) diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 49e0256aab1..8a8699890ce 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -272,7 +272,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext, Password: password, Cfg: h.Cfg, } - if err := bus.Dispatch(&authQuery); err != nil { + if err := bus.DispatchCtx(reqContext.Req.Context(), &authQuery); err != nil { reqContext.Logger.Debug( "Failed to authorize the user", "username", username, diff --git a/pkg/services/login/authinfoservice/database.go b/pkg/services/login/authinfoservice/database.go index a59183f5169..25acc107980 100644 --- a/pkg/services/login/authinfoservice/database.go +++ b/pkg/services/login/authinfoservice/database.go @@ -13,15 +13,15 @@ import ( var getTime = time.Now -func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error { +func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error { userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail} - err := s.Bus.Dispatch(&userQuery) + err := s.Bus.DispatchCtx(ctx, &userQuery) if err != nil { return err } authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id} - if err := s.Bus.Dispatch(authInfoQuery); err != nil { + if err := s.Bus.DispatchCtx(context.TODO(), authInfoQuery); err != nil { return err } diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index 8dd5cdd04d7..096dd796832 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -32,7 +32,7 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio logger: log.New("login.authinfo"), } - s.Bus.AddHandler(s.GetExternalUserInfoByLogin) + s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin) s.Bus.AddHandler(s.GetAuthInfo) s.Bus.AddHandler(s.SetAuthInfo) s.Bus.AddHandler(s.UpdateAuthInfo) diff --git a/pkg/services/notifications/notifications.go b/pkg/services/notifications/notifications.go index 4cc2be70859..f3eb08e4adb 100644 --- a/pkg/services/notifications/notifications.go +++ b/pkg/services/notifications/notifications.go @@ -32,7 +32,7 @@ func ProvideService(bus bus.Bus, cfg *setting.Cfg) (*NotificationService, error) } ns.Bus.AddHandler(ns.sendResetPasswordEmail) - ns.Bus.AddHandler(ns.validateResetPasswordCode) + ns.Bus.AddHandlerCtx(ns.validateResetPasswordCode) ns.Bus.AddHandler(ns.sendEmailCommandHandler) ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync) @@ -163,14 +163,14 @@ func (ns *NotificationService) sendResetPasswordEmail(cmd *models.SendResetPassw }) } -func (ns *NotificationService) validateResetPasswordCode(query *models.ValidateResetPasswordCodeQuery) error { +func (ns *NotificationService) validateResetPasswordCode(ctx context.Context, query *models.ValidateResetPasswordCodeQuery) error { login := getLoginForEmailCode(query.Code) if login == "" { return models.ErrInvalidEmailCode } userQuery := models.GetUserByLoginQuery{LoginOrEmail: login} - if err := bus.Dispatch(&userQuery); err != nil { + if err := bus.DispatchCtx(ctx, &userQuery); err != nil { return err } diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index b55f24e98f6..8eb2aa51841 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -38,7 +38,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI } authInfoQuery := &models.GetAuthInfoQuery{UserId: user.UserId} - if err := bus.Dispatch(authInfoQuery); err != nil { + if err := bus.DispatchCtx(ctx, authInfoQuery); err != nil { if errors.Is(err, models.ErrUserNotFound) { // Not necessarily an error. User may be logged in another way. logger.Debug("no OAuth token for user found", "userId", user.UserId, "username", user.Login) diff --git a/pkg/services/sqlstore/login_attempt.go b/pkg/services/sqlstore/login_attempt.go index 169708be0a9..2a14c285711 100644 --- a/pkg/services/sqlstore/login_attempt.go +++ b/pkg/services/sqlstore/login_attempt.go @@ -1,6 +1,7 @@ package sqlstore import ( + "context" "strconv" "time" @@ -11,12 +12,12 @@ import ( var getTimeNow = time.Now func init() { - bus.AddHandler("sql", CreateLoginAttempt) - bus.AddHandler("sql", DeleteOldLoginAttempts) - bus.AddHandler("sql", GetUserLoginAttemptCount) + bus.AddHandlerCtx("sql", CreateLoginAttempt) + bus.AddHandlerCtx("sql", DeleteOldLoginAttempts) + bus.AddHandlerCtx("sql", GetUserLoginAttemptCount) } -func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error { +func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error { return inTransaction(func(sess *DBSession) error { loginAttempt := models.LoginAttempt{ Username: cmd.Username, @@ -34,7 +35,7 @@ func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error { }) } -func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error { +func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error { return inTransaction(func(sess *DBSession) error { var maxId int64 sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?" @@ -64,7 +65,7 @@ func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error { }) } -func GetUserLoginAttemptCount(query *models.GetUserLoginAttemptCountQuery) error { +func GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error { loginAttempt := new(models.LoginAttempt) total, err := x. Where("username = ?", query.Username). diff --git a/pkg/services/sqlstore/login_attempt_test.go b/pkg/services/sqlstore/login_attempt_test.go index 42ec4b40caa..dbe16866a38 100644 --- a/pkg/services/sqlstore/login_attempt_test.go +++ b/pkg/services/sqlstore/login_attempt_test.go @@ -4,6 +4,7 @@ package sqlstore import ( + "context" "testing" "time" @@ -24,19 +25,19 @@ func TestLoginAttempts(t *testing.T) { setup := func(t *testing.T) { InitTestDB(t) beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)) - err := CreateLoginAttempt(&models.CreateLoginAttemptCommand{ + err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) require.Nil(t, err) timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1)) - err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{ + err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) require.Nil(t, err) timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2)) - err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{ + err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{ Username: user, IpAddress: "192.168.0.1", }) @@ -49,7 +50,7 @@ func TestLoginAttempts(t *testing.T) { Username: user, Since: timePlusTwoMinutes.Add(time.Second * 1), } - err := GetUserLoginAttemptCount(&query) + err := GetUserLoginAttemptCount(context.Background(), &query) require.Nil(t, err) require.Equal(t, int64(0), query.Result) }) @@ -60,7 +61,7 @@ func TestLoginAttempts(t *testing.T) { Username: user, Since: beginningOfTime, } - err := GetUserLoginAttemptCount(&query) + err := GetUserLoginAttemptCount(context.Background(), &query) require.Nil(t, err) require.Equal(t, int64(3), query.Result) }) @@ -71,7 +72,7 @@ func TestLoginAttempts(t *testing.T) { Username: user, Since: timePlusOneMinute, } - err := GetUserLoginAttemptCount(&query) + err := GetUserLoginAttemptCount(context.Background(), &query) require.Nil(t, err) require.Equal(t, int64(2), query.Result) }) @@ -82,7 +83,7 @@ func TestLoginAttempts(t *testing.T) { Username: user, Since: timePlusTwoMinutes, } - err := GetUserLoginAttemptCount(&query) + err := GetUserLoginAttemptCount(context.Background(), &query) require.Nil(t, err) require.Equal(t, int64(1), query.Result) }) @@ -92,7 +93,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: beginningOfTime, } - err := DeleteOldLoginAttempts(&cmd) + err := DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(0), cmd.DeletedRows) @@ -103,7 +104,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusOneMinute, } - err := DeleteOldLoginAttempts(&cmd) + err := DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(1), cmd.DeletedRows) @@ -114,7 +115,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusTwoMinutes, } - err := DeleteOldLoginAttempts(&cmd) + err := DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(2), cmd.DeletedRows) @@ -125,7 +126,7 @@ func TestLoginAttempts(t *testing.T) { cmd := models.DeleteOldLoginAttemptsCommand{ OlderThan: timePlusTwoMinutes.Add(time.Second * 1), } - err := DeleteOldLoginAttempts(&cmd) + err := DeleteOldLoginAttempts(context.Background(), &cmd) require.Nil(t, err) require.Equal(t, int64(3), cmd.DeletedRows)