mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: add context to login (#41316)
* Chore: add context to login attempt file and tests * Chore: add context * Chore: add context to login and login tests * Chore: continue adding context to login * Chore: add context to login query
This commit is contained in:
parent
b58cca5d51
commit
c4306f9b3e
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
@ -123,7 +124,7 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response {
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
|
||||
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not disable external user", nil)
|
||||
}
|
||||
|
||||
@ -149,7 +150,7 @@ func AdminEnableUser(c *models.ReqContext) response.Response {
|
||||
|
||||
// External users shouldn't be disabled from API
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
|
||||
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(500, "Could not enable external user", nil)
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -176,7 +177,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
|
||||
|
||||
authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.Id, AuthModule: models.AuthModuleLDAP}
|
||||
|
||||
if err := bus.Dispatch(authModuleQuery); err != nil { // validate the userId comes from LDAP
|
||||
if err := bus.DispatchCtx(context.TODO(), authModuleQuery); err != nil { // validate the userId comes from LDAP
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(404, models.ErrUserNotFound.Error(), nil)
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response {
|
||||
Cfg: hs.Cfg,
|
||||
}
|
||||
|
||||
err := bus.Dispatch(authQuery)
|
||||
err := bus.DispatchCtx(c.Req.Context(), authQuery)
|
||||
authModule = authQuery.AuthModule
|
||||
if err != nil {
|
||||
resp = response.Error(401, "Invalid username or password", err)
|
||||
|
@ -37,7 +37,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
|
||||
|
||||
// first try get existing user
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: inviteDto.LoginOrEmail}
|
||||
if err := bus.Dispatch(&userQuery); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
|
||||
if !errors.Is(err, models.ErrUserNotFound) {
|
||||
return response.Error(500, "Failed to query db for existing user check", err)
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ func SendResetPasswordEmail(c *models.ReqContext, form dtos.SendResetPasswordEma
|
||||
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: form.UserOrEmail}
|
||||
|
||||
if err := bus.Dispatch(&userQuery); err != nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
|
||||
c.Logger.Info("Requested password reset for user that was not found", "user", userQuery.LoginOrEmail)
|
||||
return response.Error(200, "Email sent", err)
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response {
|
||||
}
|
||||
|
||||
existing := models.GetUserByLoginQuery{LoginOrEmail: form.Email}
|
||||
if err := bus.Dispatch(&existing); err == nil {
|
||||
if err := bus.DispatchCtx(c.Req.Context(), &existing); err == nil {
|
||||
return response.Error(422, "User with same email address already exists", nil)
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -25,12 +26,12 @@ var (
|
||||
var loginLogger = log.New("login")
|
||||
|
||||
func Init() {
|
||||
bus.AddHandler("auth", authenticateUser)
|
||||
bus.AddHandlerCtx("auth", authenticateUser)
|
||||
}
|
||||
|
||||
// authenticateUser authenticates the user via username & password
|
||||
func authenticateUser(query *models.LoginUserQuery) error {
|
||||
if err := validateLoginAttempts(query); err != nil {
|
||||
func authenticateUser(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
if err := validateLoginAttempts(ctx, query); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -38,14 +39,14 @@ func authenticateUser(query *models.LoginUserQuery) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err := loginUsingGrafanaDB(query)
|
||||
err := loginUsingGrafanaDB(ctx, query)
|
||||
if err == nil || (!errors.Is(err, models.ErrUserNotFound) && !errors.Is(err, ErrInvalidCredentials) &&
|
||||
!errors.Is(err, ErrUserDisabled)) {
|
||||
query.AuthModule = "grafana"
|
||||
return err
|
||||
}
|
||||
|
||||
ldapEnabled, ldapErr := loginUsingLDAP(query)
|
||||
ldapEnabled, ldapErr := loginUsingLDAP(ctx, query)
|
||||
if ldapEnabled {
|
||||
query.AuthModule = models.AuthModuleLDAP
|
||||
if ldapErr == nil || !errors.Is(ldapErr, ldap.ErrInvalidCredentials) {
|
||||
@ -58,7 +59,7 @@ func authenticateUser(query *models.LoginUserQuery) error {
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) {
|
||||
if err := saveInvalidLoginAttempt(query); err != nil {
|
||||
if err := saveInvalidLoginAttempt(ctx, query); err != nil {
|
||||
loginLogger.Error("Failed to save invalid login attempt", "err", err)
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@ -20,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
Username: "user",
|
||||
Password: "",
|
||||
}
|
||||
err := authenticateUser(&loginQuery)
|
||||
err := authenticateUser(context.Background(), &loginQuery)
|
||||
|
||||
require.EqualError(t, err, ErrPasswordEmpty.Error())
|
||||
assert.False(t, sc.grafanaLoginWasCalled)
|
||||
@ -34,7 +35,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrTooManyLoginAttempts.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -50,7 +51,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -67,7 +68,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, customErr.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -83,7 +84,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(false, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, models.ErrUserNotFound.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -99,7 +100,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrInvalidCredentials.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -115,7 +116,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -132,7 +133,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, customErr, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, customErr.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -148,7 +149,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
err := authenticateUser(sc.loginUserQuery)
|
||||
err := authenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrInvalidCredentials.Error())
|
||||
assert.True(t, sc.loginAttemptValidationWasCalled)
|
||||
@ -169,28 +170,28 @@ type authScenarioContext struct {
|
||||
type authScenarioFunc func(sc *authScenarioContext)
|
||||
|
||||
func mockLoginUsingGrafanaDB(err error, sc *authScenarioContext) {
|
||||
loginUsingGrafanaDB = func(query *models.LoginUserQuery) error {
|
||||
loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
sc.grafanaLoginWasCalled = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
|
||||
loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
|
||||
loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
|
||||
sc.ldapLoginWasCalled = true
|
||||
return enabled, err
|
||||
}
|
||||
}
|
||||
|
||||
func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
|
||||
validateLoginAttempts = func(*models.LoginUserQuery) error {
|
||||
validateLoginAttempts = func(context.Context, *models.LoginUserQuery) error {
|
||||
sc.loginAttemptValidationWasCalled = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func mockSaveInvalidLoginAttempt(sc *authScenarioContext) {
|
||||
saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
|
||||
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
sc.saveInvalidLoginAttemptWasCalled = true
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -12,7 +13,7 @@ var (
|
||||
loginAttemptsWindow = time.Minute * 5
|
||||
)
|
||||
|
||||
var validateLoginAttempts = func(query *models.LoginUserQuery) error {
|
||||
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
if query.Cfg.DisableBruteForceLoginProtection {
|
||||
return nil
|
||||
}
|
||||
@ -22,7 +23,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
|
||||
Since: time.Now().Add(-loginAttemptsWindow),
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(&loginAttemptCountQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, &loginAttemptCountQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -33,7 +34,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
|
||||
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
if query.Cfg.DisableBruteForceLoginProtection {
|
||||
return nil
|
||||
}
|
||||
@ -43,5 +44,5 @@ var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
|
||||
IpAddress: query.IpAddress,
|
||||
}
|
||||
|
||||
return bus.Dispatch(&loginAttemptCommand)
|
||||
return bus.DispatchCtx(ctx, &loginAttemptCommand)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -61,7 +62,7 @@ func TestValidateLoginAttempts(t *testing.T) {
|
||||
withLoginAttempts(t, tc.loginAttempts)
|
||||
|
||||
query := &models.LoginUserQuery{Username: "user", Cfg: tc.cfg}
|
||||
err := validateLoginAttempts(query)
|
||||
err := validateLoginAttempts(context.Background(), query)
|
||||
require.Equal(t, tc.expected, err)
|
||||
})
|
||||
}
|
||||
@ -77,7 +78,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
err := saveInvalidLoginAttempt(&models.LoginUserQuery{
|
||||
err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
|
||||
Username: "user",
|
||||
Password: "pwd",
|
||||
IpAddress: "192.168.1.1:56433",
|
||||
@ -99,7 +100,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
err := saveInvalidLoginAttempt(&models.LoginUserQuery{
|
||||
err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
|
||||
Username: "user",
|
||||
Password: "pwd",
|
||||
IpAddress: "192.168.1.1:56433",
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -20,10 +21,10 @@ var validatePassword = func(providedPassword string, userPassword string, userSa
|
||||
return nil
|
||||
}
|
||||
|
||||
var loginUsingGrafanaDB = func(query *models.LoginUserQuery) error {
|
||||
var loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.Username}
|
||||
|
||||
if err := bus.Dispatch(&userQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -12,7 +13,7 @@ import (
|
||||
func TestLoginUsingGrafanaDB(t *testing.T) {
|
||||
grafanaLoginScenario(t, "When login with non-existing user", func(sc *grafanaLoginScenarioContext) {
|
||||
sc.withNonExistingUser()
|
||||
err := loginUsingGrafanaDB(sc.loginUserQuery)
|
||||
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
|
||||
require.EqualError(t, err, models.ErrUserNotFound.Error())
|
||||
|
||||
assert.False(t, sc.validatePasswordCalled)
|
||||
@ -21,7 +22,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
|
||||
|
||||
grafanaLoginScenario(t, "When login with invalid credentials", func(sc *grafanaLoginScenarioContext) {
|
||||
sc.withInvalidPassword()
|
||||
err := loginUsingGrafanaDB(sc.loginUserQuery)
|
||||
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrInvalidCredentials.Error())
|
||||
|
||||
@ -31,7 +32,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
|
||||
|
||||
grafanaLoginScenario(t, "When login with valid credentials", func(sc *grafanaLoginScenarioContext) {
|
||||
sc.withValidCredentials()
|
||||
err := loginUsingGrafanaDB(sc.loginUserQuery)
|
||||
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, sc.validatePasswordCalled)
|
||||
@ -43,7 +44,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
|
||||
|
||||
grafanaLoginScenario(t, "When login with disabled user", func(sc *grafanaLoginScenarioContext) {
|
||||
sc.withDisabledUser()
|
||||
err := loginUsingGrafanaDB(sc.loginUserQuery)
|
||||
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
|
||||
require.EqualError(t, err, ErrUserDisabled.Error())
|
||||
|
||||
assert.False(t, sc.validatePasswordCalled)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@ -26,7 +27,7 @@ var ldapLogger = log.New("login.ldap")
|
||||
|
||||
// loginUsingLDAP logs in user using LDAP. It returns whether LDAP is enabled and optional error and query arg will be
|
||||
// populated with the logged in user if successful.
|
||||
var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
|
||||
var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
|
||||
enabled := isLDAPEnabled()
|
||||
|
||||
if !enabled {
|
||||
@ -57,7 +58,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
|
||||
ExternalUser: externalUser,
|
||||
SignupAllowed: setting.LDAPAllowSignup,
|
||||
}
|
||||
err = bus.Dispatch(upsert)
|
||||
err = bus.DispatchCtx(ctx, upsert)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@ -27,7 +28,7 @@ func TestLoginUsingLDAP(t *testing.T) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
enabled, err := loginUsingLDAP(sc.loginUserQuery)
|
||||
enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
|
||||
require.EqualError(t, err, errTest.Error())
|
||||
|
||||
assert.True(t, enabled)
|
||||
@ -38,7 +39,7 @@ func TestLoginUsingLDAP(t *testing.T) {
|
||||
setting.LDAPEnabled = false
|
||||
|
||||
sc.withLoginResult(false)
|
||||
enabled, err := loginUsingLDAP(sc.loginUserQuery)
|
||||
enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, enabled)
|
||||
|
@ -50,11 +50,11 @@ func (srv *CleanUpService) Run(ctx context.Context) error {
|
||||
srv.deleteExpiredSnapshots()
|
||||
srv.deleteExpiredDashboardVersions()
|
||||
srv.cleanUpOldAnnotations(ctxWithTimeout)
|
||||
srv.expireOldUserInvites()
|
||||
srv.deleteStaleShortURLs()
|
||||
srv.expireOldUserInvites(ctx)
|
||||
srv.deleteStaleShortURLs(ctx)
|
||||
err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts",
|
||||
time.Minute*10, func(context.Context) {
|
||||
srv.deleteOldLoginAttempts()
|
||||
srv.deleteOldLoginAttempts(ctx)
|
||||
})
|
||||
if err != nil {
|
||||
srv.log.Error("failed to lock and execute cleanup of old login attempts", "error", err)
|
||||
@ -143,7 +143,7 @@ func (srv *CleanUpService) deleteExpiredDashboardVersions() {
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *CleanUpService) deleteOldLoginAttempts() {
|
||||
func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) {
|
||||
if srv.Cfg.DisableBruteForceLoginProtection {
|
||||
return
|
||||
}
|
||||
@ -151,31 +151,31 @@ func (srv *CleanUpService) deleteOldLoginAttempts() {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: time.Now().Add(time.Minute * -10),
|
||||
}
|
||||
if err := bus.Dispatch(&cmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
|
||||
srv.log.Error("Problem deleting expired login attempts", "error", err.Error())
|
||||
} else {
|
||||
srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *CleanUpService) expireOldUserInvites() {
|
||||
func (srv *CleanUpService) expireOldUserInvites(ctx context.Context) {
|
||||
maxInviteLifetime := srv.Cfg.UserInviteMaxLifetime
|
||||
|
||||
cmd := models.ExpireTempUsersCommand{
|
||||
OlderThan: time.Now().Add(-maxInviteLifetime),
|
||||
}
|
||||
if err := bus.Dispatch(&cmd); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
|
||||
srv.log.Error("Problem expiring user invites", "error", err.Error())
|
||||
} else {
|
||||
srv.log.Debug("Expired user invites", "rows affected", cmd.NumExpired)
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *CleanUpService) deleteStaleShortURLs() {
|
||||
func (srv *CleanUpService) deleteStaleShortURLs(ctx context.Context) {
|
||||
cmd := models.DeleteShortUrlCommand{
|
||||
OlderThan: time.Now().Add(-time.Hour * 24 * 7),
|
||||
}
|
||||
if err := srv.ShortURLService.DeleteStaleShortURLs(context.Background(), &cmd); err != nil {
|
||||
if err := srv.ShortURLService.DeleteStaleShortURLs(ctx, &cmd); err != nil {
|
||||
srv.log.Error("Problem deleting stale short urls", "error", err.Error())
|
||||
} else {
|
||||
srv.log.Debug("Deleted short urls", "rows affected", cmd.NumDeleted)
|
||||
|
@ -272,7 +272,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext,
|
||||
Password: password,
|
||||
Cfg: h.Cfg,
|
||||
}
|
||||
if err := bus.Dispatch(&authQuery); err != nil {
|
||||
if err := bus.DispatchCtx(reqContext.Req.Context(), &authQuery); err != nil {
|
||||
reqContext.Logger.Debug(
|
||||
"Failed to authorize the user",
|
||||
"username", username,
|
||||
|
@ -13,15 +13,15 @@ import (
|
||||
|
||||
var getTime = time.Now
|
||||
|
||||
func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
|
||||
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
|
||||
err := s.Bus.Dispatch(&userQuery)
|
||||
err := s.Bus.DispatchCtx(ctx, &userQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id}
|
||||
if err := s.Bus.Dispatch(authInfoQuery); err != nil {
|
||||
if err := s.Bus.DispatchCtx(context.TODO(), authInfoQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,7 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio
|
||||
logger: log.New("login.authinfo"),
|
||||
}
|
||||
|
||||
s.Bus.AddHandler(s.GetExternalUserInfoByLogin)
|
||||
s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin)
|
||||
s.Bus.AddHandler(s.GetAuthInfo)
|
||||
s.Bus.AddHandler(s.SetAuthInfo)
|
||||
s.Bus.AddHandler(s.UpdateAuthInfo)
|
||||
|
@ -32,7 +32,7 @@ func ProvideService(bus bus.Bus, cfg *setting.Cfg) (*NotificationService, error)
|
||||
}
|
||||
|
||||
ns.Bus.AddHandler(ns.sendResetPasswordEmail)
|
||||
ns.Bus.AddHandler(ns.validateResetPasswordCode)
|
||||
ns.Bus.AddHandlerCtx(ns.validateResetPasswordCode)
|
||||
ns.Bus.AddHandler(ns.sendEmailCommandHandler)
|
||||
|
||||
ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync)
|
||||
@ -163,14 +163,14 @@ func (ns *NotificationService) sendResetPasswordEmail(cmd *models.SendResetPassw
|
||||
})
|
||||
}
|
||||
|
||||
func (ns *NotificationService) validateResetPasswordCode(query *models.ValidateResetPasswordCodeQuery) error {
|
||||
func (ns *NotificationService) validateResetPasswordCode(ctx context.Context, query *models.ValidateResetPasswordCodeQuery) error {
|
||||
login := getLoginForEmailCode(query.Code)
|
||||
if login == "" {
|
||||
return models.ErrInvalidEmailCode
|
||||
}
|
||||
|
||||
userQuery := models.GetUserByLoginQuery{LoginOrEmail: login}
|
||||
if err := bus.Dispatch(&userQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
|
||||
}
|
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: user.UserId}
|
||||
if err := bus.Dispatch(authInfoQuery); err != nil {
|
||||
if err := bus.DispatchCtx(ctx, authInfoQuery); err != nil {
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("no OAuth token for user found", "userId", user.UserId, "username", user.Login)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@ -11,12 +12,12 @@ import (
|
||||
var getTimeNow = time.Now
|
||||
|
||||
func init() {
|
||||
bus.AddHandler("sql", CreateLoginAttempt)
|
||||
bus.AddHandler("sql", DeleteOldLoginAttempts)
|
||||
bus.AddHandler("sql", GetUserLoginAttemptCount)
|
||||
bus.AddHandlerCtx("sql", CreateLoginAttempt)
|
||||
bus.AddHandlerCtx("sql", DeleteOldLoginAttempts)
|
||||
bus.AddHandlerCtx("sql", GetUserLoginAttemptCount)
|
||||
}
|
||||
|
||||
func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error {
|
||||
func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
loginAttempt := models.LoginAttempt{
|
||||
Username: cmd.Username,
|
||||
@ -34,7 +35,7 @@ func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error {
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
var maxId int64
|
||||
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
|
||||
@ -64,7 +65,7 @@ func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserLoginAttemptCount(query *models.GetUserLoginAttemptCountQuery) error {
|
||||
func GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
|
||||
loginAttempt := new(models.LoginAttempt)
|
||||
total, err := x.
|
||||
Where("username = ?", query.Username).
|
||||
|
@ -4,6 +4,7 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -24,19 +25,19 @@ func TestLoginAttempts(t *testing.T) {
|
||||
setup := func(t *testing.T) {
|
||||
InitTestDB(t)
|
||||
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
|
||||
err := CreateLoginAttempt(&models.CreateLoginAttemptCommand{
|
||||
err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
|
||||
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{
|
||||
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
|
||||
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{
|
||||
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
@ -49,7 +50,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
Username: user,
|
||||
Since: timePlusTwoMinutes.Add(time.Second * 1),
|
||||
}
|
||||
err := GetUserLoginAttemptCount(&query)
|
||||
err := GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), query.Result)
|
||||
})
|
||||
@ -60,7 +61,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
Username: user,
|
||||
Since: beginningOfTime,
|
||||
}
|
||||
err := GetUserLoginAttemptCount(&query)
|
||||
err := GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(3), query.Result)
|
||||
})
|
||||
@ -71,7 +72,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
Username: user,
|
||||
Since: timePlusOneMinute,
|
||||
}
|
||||
err := GetUserLoginAttemptCount(&query)
|
||||
err := GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), query.Result)
|
||||
})
|
||||
@ -82,7 +83,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
Username: user,
|
||||
Since: timePlusTwoMinutes,
|
||||
}
|
||||
err := GetUserLoginAttemptCount(&query)
|
||||
err := GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1), query.Result)
|
||||
})
|
||||
@ -92,7 +93,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: beginningOfTime,
|
||||
}
|
||||
err := DeleteOldLoginAttempts(&cmd)
|
||||
err := DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), cmd.DeletedRows)
|
||||
@ -103,7 +104,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusOneMinute,
|
||||
}
|
||||
err := DeleteOldLoginAttempts(&cmd)
|
||||
err := DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1), cmd.DeletedRows)
|
||||
@ -114,7 +115,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusTwoMinutes,
|
||||
}
|
||||
err := DeleteOldLoginAttempts(&cmd)
|
||||
err := DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), cmd.DeletedRows)
|
||||
@ -125,7 +126,7 @@ func TestLoginAttempts(t *testing.T) {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
|
||||
}
|
||||
err := DeleteOldLoginAttempts(&cmd)
|
||||
err := DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(3), cmd.DeletedRows)
|
||||
|
Loading…
Reference in New Issue
Block a user