Chore: add context to login (#41316)

* Chore: add context to login attempt file and tests

* Chore: add context

* Chore: add context to login and login tests

* Chore: continue adding context to login

* Chore: add context to login query
This commit is contained in:
Katarina Yang 2021-11-08 09:53:51 -05:00 committed by GitHub
parent b58cca5d51
commit c4306f9b3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 90 additions and 78 deletions

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -123,7 +124,7 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API // External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not disable external user", nil) return response.Error(500, "Could not disable external user", nil)
} }
@ -149,7 +150,7 @@ func AdminEnableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API // External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID} authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) { if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not enable external user", nil) return response.Error(500, "Could not enable external user", nil)
} }

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -176,7 +177,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.Id, AuthModule: models.AuthModuleLDAP} authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.Id, AuthModule: models.AuthModuleLDAP}
if err := bus.Dispatch(authModuleQuery); err != nil { // validate the userId comes from LDAP if err := bus.DispatchCtx(context.TODO(), authModuleQuery); err != nil { // validate the userId comes from LDAP
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil) return response.Error(404, models.ErrUserNotFound.Error(), nil)
} }

View File

@ -208,7 +208,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response {
Cfg: hs.Cfg, Cfg: hs.Cfg,
} }
err := bus.Dispatch(authQuery) err := bus.DispatchCtx(c.Req.Context(), authQuery)
authModule = authQuery.AuthModule authModule = authQuery.AuthModule
if err != nil { if err != nil {
resp = response.Error(401, "Invalid username or password", err) resp = response.Error(401, "Invalid username or password", err)

View File

@ -37,7 +37,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
// first try get existing user // first try get existing user
userQuery := models.GetUserByLoginQuery{LoginOrEmail: inviteDto.LoginOrEmail} userQuery := models.GetUserByLoginQuery{LoginOrEmail: inviteDto.LoginOrEmail}
if err := bus.Dispatch(&userQuery); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
if !errors.Is(err, models.ErrUserNotFound) { if !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Failed to query db for existing user check", err) return response.Error(500, "Failed to query db for existing user check", err)
} }

View File

@ -21,7 +21,7 @@ func SendResetPasswordEmail(c *models.ReqContext, form dtos.SendResetPasswordEma
userQuery := models.GetUserByLoginQuery{LoginOrEmail: form.UserOrEmail} userQuery := models.GetUserByLoginQuery{LoginOrEmail: form.UserOrEmail}
if err := bus.Dispatch(&userQuery); err != nil { if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
c.Logger.Info("Requested password reset for user that was not found", "user", userQuery.LoginOrEmail) c.Logger.Info("Requested password reset for user that was not found", "user", userQuery.LoginOrEmail)
return response.Error(200, "Email sent", err) return response.Error(200, "Email sent", err)
} }

View File

@ -29,7 +29,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response {
} }
existing := models.GetUserByLoginQuery{LoginOrEmail: form.Email} existing := models.GetUserByLoginQuery{LoginOrEmail: form.Email}
if err := bus.Dispatch(&existing); err == nil { if err := bus.DispatchCtx(c.Req.Context(), &existing); err == nil {
return response.Error(422, "User with same email address already exists", nil) return response.Error(422, "User with same email address already exists", nil)
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -25,12 +26,12 @@ var (
var loginLogger = log.New("login") var loginLogger = log.New("login")
func Init() { func Init() {
bus.AddHandler("auth", authenticateUser) bus.AddHandlerCtx("auth", authenticateUser)
} }
// authenticateUser authenticates the user via username & password // authenticateUser authenticates the user via username & password
func authenticateUser(query *models.LoginUserQuery) error { func authenticateUser(ctx context.Context, query *models.LoginUserQuery) error {
if err := validateLoginAttempts(query); err != nil { if err := validateLoginAttempts(ctx, query); err != nil {
return err return err
} }
@ -38,14 +39,14 @@ func authenticateUser(query *models.LoginUserQuery) error {
return err return err
} }
err := loginUsingGrafanaDB(query) err := loginUsingGrafanaDB(ctx, query)
if err == nil || (!errors.Is(err, models.ErrUserNotFound) && !errors.Is(err, ErrInvalidCredentials) && if err == nil || (!errors.Is(err, models.ErrUserNotFound) && !errors.Is(err, ErrInvalidCredentials) &&
!errors.Is(err, ErrUserDisabled)) { !errors.Is(err, ErrUserDisabled)) {
query.AuthModule = "grafana" query.AuthModule = "grafana"
return err return err
} }
ldapEnabled, ldapErr := loginUsingLDAP(query) ldapEnabled, ldapErr := loginUsingLDAP(ctx, query)
if ldapEnabled { if ldapEnabled {
query.AuthModule = models.AuthModuleLDAP query.AuthModule = models.AuthModuleLDAP
if ldapErr == nil || !errors.Is(ldapErr, ldap.ErrInvalidCredentials) { if ldapErr == nil || !errors.Is(ldapErr, ldap.ErrInvalidCredentials) {
@ -58,7 +59,7 @@ func authenticateUser(query *models.LoginUserQuery) error {
} }
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) {
if err := saveInvalidLoginAttempt(query); err != nil { if err := saveInvalidLoginAttempt(ctx, query); err != nil {
loginLogger.Error("Failed to save invalid login attempt", "err", err) loginLogger.Error("Failed to save invalid login attempt", "err", err)
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -20,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) {
Username: "user", Username: "user",
Password: "", Password: "",
} }
err := authenticateUser(&loginQuery) err := authenticateUser(context.Background(), &loginQuery)
require.EqualError(t, err, ErrPasswordEmpty.Error()) require.EqualError(t, err, ErrPasswordEmpty.Error())
assert.False(t, sc.grafanaLoginWasCalled) assert.False(t, sc.grafanaLoginWasCalled)
@ -34,7 +35,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc) mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrTooManyLoginAttempts.Error()) require.EqualError(t, err, ErrTooManyLoginAttempts.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -50,7 +51,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -67,7 +68,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error()) require.EqualError(t, err, customErr.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -83,7 +84,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(false, nil, sc) mockLoginUsingLDAP(false, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, models.ErrUserNotFound.Error()) require.EqualError(t, err, models.ErrUserNotFound.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -99,7 +100,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error()) require.EqualError(t, err, ErrInvalidCredentials.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -115,7 +116,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc) mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -132,7 +133,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, customErr, sc) mockLoginUsingLDAP(true, customErr, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error()) require.EqualError(t, err, customErr.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -148,7 +149,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery) err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error()) require.EqualError(t, err, ErrInvalidCredentials.Error())
assert.True(t, sc.loginAttemptValidationWasCalled) assert.True(t, sc.loginAttemptValidationWasCalled)
@ -169,28 +170,28 @@ type authScenarioContext struct {
type authScenarioFunc func(sc *authScenarioContext) type authScenarioFunc func(sc *authScenarioContext)
func mockLoginUsingGrafanaDB(err error, sc *authScenarioContext) { func mockLoginUsingGrafanaDB(err error, sc *authScenarioContext) {
loginUsingGrafanaDB = func(query *models.LoginUserQuery) error { loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
sc.grafanaLoginWasCalled = true sc.grafanaLoginWasCalled = true
return err return err
} }
} }
func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) { func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
sc.ldapLoginWasCalled = true sc.ldapLoginWasCalled = true
return enabled, err return enabled, err
} }
} }
func mockLoginAttemptValidation(err error, sc *authScenarioContext) { func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
validateLoginAttempts = func(*models.LoginUserQuery) error { validateLoginAttempts = func(context.Context, *models.LoginUserQuery) error {
sc.loginAttemptValidationWasCalled = true sc.loginAttemptValidationWasCalled = true
return err return err
} }
} }
func mockSaveInvalidLoginAttempt(sc *authScenarioContext) { func mockSaveInvalidLoginAttempt(sc *authScenarioContext) {
saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
sc.saveInvalidLoginAttemptWasCalled = true sc.saveInvalidLoginAttemptWasCalled = true
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"time" "time"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -12,7 +13,7 @@ var (
loginAttemptsWindow = time.Minute * 5 loginAttemptsWindow = time.Minute * 5
) )
var validateLoginAttempts = func(query *models.LoginUserQuery) error { var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery) error {
if query.Cfg.DisableBruteForceLoginProtection { if query.Cfg.DisableBruteForceLoginProtection {
return nil return nil
} }
@ -22,7 +23,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
Since: time.Now().Add(-loginAttemptsWindow), Since: time.Now().Add(-loginAttemptsWindow),
} }
if err := bus.Dispatch(&loginAttemptCountQuery); err != nil { if err := bus.DispatchCtx(ctx, &loginAttemptCountQuery); err != nil {
return err return err
} }
@ -33,7 +34,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
return nil return nil
} }
var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
if query.Cfg.DisableBruteForceLoginProtection { if query.Cfg.DisableBruteForceLoginProtection {
return nil return nil
} }
@ -43,5 +44,5 @@ var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
IpAddress: query.IpAddress, IpAddress: query.IpAddress,
} }
return bus.Dispatch(&loginAttemptCommand) return bus.DispatchCtx(ctx, &loginAttemptCommand)
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"testing" "testing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -61,7 +62,7 @@ func TestValidateLoginAttempts(t *testing.T) {
withLoginAttempts(t, tc.loginAttempts) withLoginAttempts(t, tc.loginAttempts)
query := &models.LoginUserQuery{Username: "user", Cfg: tc.cfg} query := &models.LoginUserQuery{Username: "user", Cfg: tc.cfg}
err := validateLoginAttempts(query) err := validateLoginAttempts(context.Background(), query)
require.Equal(t, tc.expected, err) require.Equal(t, tc.expected, err)
}) })
} }
@ -77,7 +78,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
return nil return nil
}) })
err := saveInvalidLoginAttempt(&models.LoginUserQuery{ err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
Username: "user", Username: "user",
Password: "pwd", Password: "pwd",
IpAddress: "192.168.1.1:56433", IpAddress: "192.168.1.1:56433",
@ -99,7 +100,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
return nil return nil
}) })
err := saveInvalidLoginAttempt(&models.LoginUserQuery{ err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
Username: "user", Username: "user",
Password: "pwd", Password: "pwd",
IpAddress: "192.168.1.1:56433", IpAddress: "192.168.1.1:56433",

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"crypto/subtle" "crypto/subtle"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -20,10 +21,10 @@ var validatePassword = func(providedPassword string, userPassword string, userSa
return nil return nil
} }
var loginUsingGrafanaDB = func(query *models.LoginUserQuery) error { var loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.Username} userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.Username}
if err := bus.Dispatch(&userQuery); err != nil { if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"testing" "testing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -12,7 +13,7 @@ import (
func TestLoginUsingGrafanaDB(t *testing.T) { func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with non-existing user", func(sc *grafanaLoginScenarioContext) { grafanaLoginScenario(t, "When login with non-existing user", func(sc *grafanaLoginScenarioContext) {
sc.withNonExistingUser() sc.withNonExistingUser()
err := loginUsingGrafanaDB(sc.loginUserQuery) err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, models.ErrUserNotFound.Error()) require.EqualError(t, err, models.ErrUserNotFound.Error())
assert.False(t, sc.validatePasswordCalled) assert.False(t, sc.validatePasswordCalled)
@ -21,7 +22,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with invalid credentials", func(sc *grafanaLoginScenarioContext) { grafanaLoginScenario(t, "When login with invalid credentials", func(sc *grafanaLoginScenarioContext) {
sc.withInvalidPassword() sc.withInvalidPassword()
err := loginUsingGrafanaDB(sc.loginUserQuery) err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error()) require.EqualError(t, err, ErrInvalidCredentials.Error())
@ -31,7 +32,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with valid credentials", func(sc *grafanaLoginScenarioContext) { grafanaLoginScenario(t, "When login with valid credentials", func(sc *grafanaLoginScenarioContext) {
sc.withValidCredentials() sc.withValidCredentials()
err := loginUsingGrafanaDB(sc.loginUserQuery) err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, sc.validatePasswordCalled) assert.True(t, sc.validatePasswordCalled)
@ -43,7 +44,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with disabled user", func(sc *grafanaLoginScenarioContext) { grafanaLoginScenario(t, "When login with disabled user", func(sc *grafanaLoginScenarioContext) {
sc.withDisabledUser() sc.withDisabledUser()
err := loginUsingGrafanaDB(sc.loginUserQuery) err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrUserDisabled.Error()) require.EqualError(t, err, ErrUserDisabled.Error())
assert.False(t, sc.validatePasswordCalled) assert.False(t, sc.validatePasswordCalled)

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
@ -26,7 +27,7 @@ var ldapLogger = log.New("login.ldap")
// loginUsingLDAP logs in user using LDAP. It returns whether LDAP is enabled and optional error and query arg will be // loginUsingLDAP logs in user using LDAP. It returns whether LDAP is enabled and optional error and query arg will be
// populated with the logged in user if successful. // populated with the logged in user if successful.
var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
enabled := isLDAPEnabled() enabled := isLDAPEnabled()
if !enabled { if !enabled {
@ -57,7 +58,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
ExternalUser: externalUser, ExternalUser: externalUser,
SignupAllowed: setting.LDAPAllowSignup, SignupAllowed: setting.LDAPAllowSignup,
} }
err = bus.Dispatch(upsert) err = bus.DispatchCtx(ctx, upsert)
if err != nil { if err != nil {
return true, err return true, err
} }

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"context"
"errors" "errors"
"testing" "testing"
@ -27,7 +28,7 @@ func TestLoginUsingLDAP(t *testing.T) {
return config, nil return config, nil
} }
enabled, err := loginUsingLDAP(sc.loginUserQuery) enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, errTest.Error()) require.EqualError(t, err, errTest.Error())
assert.True(t, enabled) assert.True(t, enabled)
@ -38,7 +39,7 @@ func TestLoginUsingLDAP(t *testing.T) {
setting.LDAPEnabled = false setting.LDAPEnabled = false
sc.withLoginResult(false) sc.withLoginResult(false)
enabled, err := loginUsingLDAP(sc.loginUserQuery) enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, enabled) assert.False(t, enabled)

View File

@ -50,11 +50,11 @@ func (srv *CleanUpService) Run(ctx context.Context) error {
srv.deleteExpiredSnapshots() srv.deleteExpiredSnapshots()
srv.deleteExpiredDashboardVersions() srv.deleteExpiredDashboardVersions()
srv.cleanUpOldAnnotations(ctxWithTimeout) srv.cleanUpOldAnnotations(ctxWithTimeout)
srv.expireOldUserInvites() srv.expireOldUserInvites(ctx)
srv.deleteStaleShortURLs() srv.deleteStaleShortURLs(ctx)
err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts", err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts",
time.Minute*10, func(context.Context) { time.Minute*10, func(context.Context) {
srv.deleteOldLoginAttempts() srv.deleteOldLoginAttempts(ctx)
}) })
if err != nil { if err != nil {
srv.log.Error("failed to lock and execute cleanup of old login attempts", "error", err) srv.log.Error("failed to lock and execute cleanup of old login attempts", "error", err)
@ -143,7 +143,7 @@ func (srv *CleanUpService) deleteExpiredDashboardVersions() {
} }
} }
func (srv *CleanUpService) deleteOldLoginAttempts() { func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) {
if srv.Cfg.DisableBruteForceLoginProtection { if srv.Cfg.DisableBruteForceLoginProtection {
return return
} }
@ -151,31 +151,31 @@ func (srv *CleanUpService) deleteOldLoginAttempts() {
cmd := models.DeleteOldLoginAttemptsCommand{ cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: time.Now().Add(time.Minute * -10), OlderThan: time.Now().Add(time.Minute * -10),
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(ctx, &cmd); err != nil {
srv.log.Error("Problem deleting expired login attempts", "error", err.Error()) srv.log.Error("Problem deleting expired login attempts", "error", err.Error())
} else { } else {
srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows) srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows)
} }
} }
func (srv *CleanUpService) expireOldUserInvites() { func (srv *CleanUpService) expireOldUserInvites(ctx context.Context) {
maxInviteLifetime := srv.Cfg.UserInviteMaxLifetime maxInviteLifetime := srv.Cfg.UserInviteMaxLifetime
cmd := models.ExpireTempUsersCommand{ cmd := models.ExpireTempUsersCommand{
OlderThan: time.Now().Add(-maxInviteLifetime), OlderThan: time.Now().Add(-maxInviteLifetime),
} }
if err := bus.Dispatch(&cmd); err != nil { if err := bus.DispatchCtx(ctx, &cmd); err != nil {
srv.log.Error("Problem expiring user invites", "error", err.Error()) srv.log.Error("Problem expiring user invites", "error", err.Error())
} else { } else {
srv.log.Debug("Expired user invites", "rows affected", cmd.NumExpired) srv.log.Debug("Expired user invites", "rows affected", cmd.NumExpired)
} }
} }
func (srv *CleanUpService) deleteStaleShortURLs() { func (srv *CleanUpService) deleteStaleShortURLs(ctx context.Context) {
cmd := models.DeleteShortUrlCommand{ cmd := models.DeleteShortUrlCommand{
OlderThan: time.Now().Add(-time.Hour * 24 * 7), OlderThan: time.Now().Add(-time.Hour * 24 * 7),
} }
if err := srv.ShortURLService.DeleteStaleShortURLs(context.Background(), &cmd); err != nil { if err := srv.ShortURLService.DeleteStaleShortURLs(ctx, &cmd); err != nil {
srv.log.Error("Problem deleting stale short urls", "error", err.Error()) srv.log.Error("Problem deleting stale short urls", "error", err.Error())
} else { } else {
srv.log.Debug("Deleted short urls", "rows affected", cmd.NumDeleted) srv.log.Debug("Deleted short urls", "rows affected", cmd.NumDeleted)

View File

@ -272,7 +272,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext,
Password: password, Password: password,
Cfg: h.Cfg, Cfg: h.Cfg,
} }
if err := bus.Dispatch(&authQuery); err != nil { if err := bus.DispatchCtx(reqContext.Req.Context(), &authQuery); err != nil {
reqContext.Logger.Debug( reqContext.Logger.Debug(
"Failed to authorize the user", "Failed to authorize the user",
"username", username, "username", username,

View File

@ -13,15 +13,15 @@ import (
var getTime = time.Now var getTime = time.Now
func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error { func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail} userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
err := s.Bus.Dispatch(&userQuery) err := s.Bus.DispatchCtx(ctx, &userQuery)
if err != nil { if err != nil {
return err return err
} }
authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id} authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id}
if err := s.Bus.Dispatch(authInfoQuery); err != nil { if err := s.Bus.DispatchCtx(context.TODO(), authInfoQuery); err != nil {
return err return err
} }

View File

@ -32,7 +32,7 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio
logger: log.New("login.authinfo"), logger: log.New("login.authinfo"),
} }
s.Bus.AddHandler(s.GetExternalUserInfoByLogin) s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin)
s.Bus.AddHandler(s.GetAuthInfo) s.Bus.AddHandler(s.GetAuthInfo)
s.Bus.AddHandler(s.SetAuthInfo) s.Bus.AddHandler(s.SetAuthInfo)
s.Bus.AddHandler(s.UpdateAuthInfo) s.Bus.AddHandler(s.UpdateAuthInfo)

View File

@ -32,7 +32,7 @@ func ProvideService(bus bus.Bus, cfg *setting.Cfg) (*NotificationService, error)
} }
ns.Bus.AddHandler(ns.sendResetPasswordEmail) ns.Bus.AddHandler(ns.sendResetPasswordEmail)
ns.Bus.AddHandler(ns.validateResetPasswordCode) ns.Bus.AddHandlerCtx(ns.validateResetPasswordCode)
ns.Bus.AddHandler(ns.sendEmailCommandHandler) ns.Bus.AddHandler(ns.sendEmailCommandHandler)
ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync) ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync)
@ -163,14 +163,14 @@ func (ns *NotificationService) sendResetPasswordEmail(cmd *models.SendResetPassw
}) })
} }
func (ns *NotificationService) validateResetPasswordCode(query *models.ValidateResetPasswordCodeQuery) error { func (ns *NotificationService) validateResetPasswordCode(ctx context.Context, query *models.ValidateResetPasswordCodeQuery) error {
login := getLoginForEmailCode(query.Code) login := getLoginForEmailCode(query.Code)
if login == "" { if login == "" {
return models.ErrInvalidEmailCode return models.ErrInvalidEmailCode
} }
userQuery := models.GetUserByLoginQuery{LoginOrEmail: login} userQuery := models.GetUserByLoginQuery{LoginOrEmail: login}
if err := bus.Dispatch(&userQuery); err != nil { if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
return err return err
} }

View File

@ -38,7 +38,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
} }
authInfoQuery := &models.GetAuthInfoQuery{UserId: user.UserId} authInfoQuery := &models.GetAuthInfoQuery{UserId: user.UserId}
if err := bus.Dispatch(authInfoQuery); err != nil { if err := bus.DispatchCtx(ctx, authInfoQuery); err != nil {
if errors.Is(err, models.ErrUserNotFound) { if errors.Is(err, models.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way. // Not necessarily an error. User may be logged in another way.
logger.Debug("no OAuth token for user found", "userId", user.UserId, "username", user.Login) logger.Debug("no OAuth token for user found", "userId", user.UserId, "username", user.Login)

View File

@ -1,6 +1,7 @@
package sqlstore package sqlstore
import ( import (
"context"
"strconv" "strconv"
"time" "time"
@ -11,12 +12,12 @@ import (
var getTimeNow = time.Now var getTimeNow = time.Now
func init() { func init() {
bus.AddHandler("sql", CreateLoginAttempt) bus.AddHandlerCtx("sql", CreateLoginAttempt)
bus.AddHandler("sql", DeleteOldLoginAttempts) bus.AddHandlerCtx("sql", DeleteOldLoginAttempts)
bus.AddHandler("sql", GetUserLoginAttemptCount) bus.AddHandlerCtx("sql", GetUserLoginAttemptCount)
} }
func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error { func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
loginAttempt := models.LoginAttempt{ loginAttempt := models.LoginAttempt{
Username: cmd.Username, Username: cmd.Username,
@ -34,7 +35,7 @@ func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error {
}) })
} }
func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error { func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
var maxId int64 var maxId int64
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?" sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
@ -64,7 +65,7 @@ func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error {
}) })
} }
func GetUserLoginAttemptCount(query *models.GetUserLoginAttemptCountQuery) error { func GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
loginAttempt := new(models.LoginAttempt) loginAttempt := new(models.LoginAttempt)
total, err := x. total, err := x.
Where("username = ?", query.Username). Where("username = ?", query.Username).

View File

@ -4,6 +4,7 @@
package sqlstore package sqlstore
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -24,19 +25,19 @@ func TestLoginAttempts(t *testing.T) {
setup := func(t *testing.T) { setup := func(t *testing.T) {
InitTestDB(t) InitTestDB(t)
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)) beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
err := CreateLoginAttempt(&models.CreateLoginAttemptCommand{ err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user, Username: user,
IpAddress: "192.168.0.1", IpAddress: "192.168.0.1",
}) })
require.Nil(t, err) require.Nil(t, err)
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1)) timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{ err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user, Username: user,
IpAddress: "192.168.0.1", IpAddress: "192.168.0.1",
}) })
require.Nil(t, err) require.Nil(t, err)
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2)) timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{ err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user, Username: user,
IpAddress: "192.168.0.1", IpAddress: "192.168.0.1",
}) })
@ -49,7 +50,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user, Username: user,
Since: timePlusTwoMinutes.Add(time.Second * 1), Since: timePlusTwoMinutes.Add(time.Second * 1),
} }
err := GetUserLoginAttemptCount(&query) err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), query.Result) require.Equal(t, int64(0), query.Result)
}) })
@ -60,7 +61,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user, Username: user,
Since: beginningOfTime, Since: beginningOfTime,
} }
err := GetUserLoginAttemptCount(&query) err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(3), query.Result) require.Equal(t, int64(3), query.Result)
}) })
@ -71,7 +72,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user, Username: user,
Since: timePlusOneMinute, Since: timePlusOneMinute,
} }
err := GetUserLoginAttemptCount(&query) err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(2), query.Result) require.Equal(t, int64(2), query.Result)
}) })
@ -82,7 +83,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user, Username: user,
Since: timePlusTwoMinutes, Since: timePlusTwoMinutes,
} }
err := GetUserLoginAttemptCount(&query) err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(1), query.Result) require.Equal(t, int64(1), query.Result)
}) })
@ -92,7 +93,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{ cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: beginningOfTime, OlderThan: beginningOfTime,
} }
err := DeleteOldLoginAttempts(&cmd) err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(0), cmd.DeletedRows) require.Equal(t, int64(0), cmd.DeletedRows)
@ -103,7 +104,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{ cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusOneMinute, OlderThan: timePlusOneMinute,
} }
err := DeleteOldLoginAttempts(&cmd) err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(1), cmd.DeletedRows) require.Equal(t, int64(1), cmd.DeletedRows)
@ -114,7 +115,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{ cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes, OlderThan: timePlusTwoMinutes,
} }
err := DeleteOldLoginAttempts(&cmd) err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(2), cmd.DeletedRows) require.Equal(t, int64(2), cmd.DeletedRows)
@ -125,7 +126,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{ cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes.Add(time.Second * 1), OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
} }
err := DeleteOldLoginAttempts(&cmd) err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, int64(3), cmd.DeletedRows) require.Equal(t, int64(3), cmd.DeletedRows)