Chore: add context to login (#41316)

* Chore: add context to login attempt file and tests

* Chore: add context

* Chore: add context to login and login tests

* Chore: continue adding context to login

* Chore: add context to login query
This commit is contained in:
Katarina Yang 2021-11-08 09:53:51 -05:00 committed by GitHub
parent b58cca5d51
commit c4306f9b3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 90 additions and 78 deletions

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"fmt"
@ -123,7 +124,7 @@ func (hs *HTTPServer) AdminDisableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not disable external user", nil)
}
@ -149,7 +150,7 @@ func AdminEnableUser(c *models.ReqContext) response.Response {
// External users shouldn't be disabled from API
authInfoQuery := &models.GetAuthInfoQuery{UserId: userID}
if err := bus.Dispatch(authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
if err := bus.DispatchCtx(context.TODO(), authInfoQuery); !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Could not enable external user", nil)
}

View File

@ -1,6 +1,7 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
@ -176,7 +177,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.Id, AuthModule: models.AuthModuleLDAP}
if err := bus.Dispatch(authModuleQuery); err != nil { // validate the userId comes from LDAP
if err := bus.DispatchCtx(context.TODO(), authModuleQuery); err != nil { // validate the userId comes from LDAP
if errors.Is(err, models.ErrUserNotFound) {
return response.Error(404, models.ErrUserNotFound.Error(), nil)
}

View File

@ -208,7 +208,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response {
Cfg: hs.Cfg,
}
err := bus.Dispatch(authQuery)
err := bus.DispatchCtx(c.Req.Context(), authQuery)
authModule = authQuery.AuthModule
if err != nil {
resp = response.Error(401, "Invalid username or password", err)

View File

@ -37,7 +37,7 @@ func AddOrgInvite(c *models.ReqContext, inviteDto dtos.AddInviteForm) response.R
// first try get existing user
userQuery := models.GetUserByLoginQuery{LoginOrEmail: inviteDto.LoginOrEmail}
if err := bus.Dispatch(&userQuery); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
if !errors.Is(err, models.ErrUserNotFound) {
return response.Error(500, "Failed to query db for existing user check", err)
}

View File

@ -21,7 +21,7 @@ func SendResetPasswordEmail(c *models.ReqContext, form dtos.SendResetPasswordEma
userQuery := models.GetUserByLoginQuery{LoginOrEmail: form.UserOrEmail}
if err := bus.Dispatch(&userQuery); err != nil {
if err := bus.DispatchCtx(c.Req.Context(), &userQuery); err != nil {
c.Logger.Info("Requested password reset for user that was not found", "user", userQuery.LoginOrEmail)
return response.Error(200, "Email sent", err)
}

View File

@ -29,7 +29,7 @@ func SignUp(c *models.ReqContext, form dtos.SignUpForm) response.Response {
}
existing := models.GetUserByLoginQuery{LoginOrEmail: form.Email}
if err := bus.Dispatch(&existing); err == nil {
if err := bus.DispatchCtx(c.Req.Context(), &existing); err == nil {
return response.Error(422, "User with same email address already exists", nil)
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"errors"
"github.com/grafana/grafana/pkg/bus"
@ -25,12 +26,12 @@ var (
var loginLogger = log.New("login")
func Init() {
bus.AddHandler("auth", authenticateUser)
bus.AddHandlerCtx("auth", authenticateUser)
}
// authenticateUser authenticates the user via username & password
func authenticateUser(query *models.LoginUserQuery) error {
if err := validateLoginAttempts(query); err != nil {
func authenticateUser(ctx context.Context, query *models.LoginUserQuery) error {
if err := validateLoginAttempts(ctx, query); err != nil {
return err
}
@ -38,14 +39,14 @@ func authenticateUser(query *models.LoginUserQuery) error {
return err
}
err := loginUsingGrafanaDB(query)
err := loginUsingGrafanaDB(ctx, query)
if err == nil || (!errors.Is(err, models.ErrUserNotFound) && !errors.Is(err, ErrInvalidCredentials) &&
!errors.Is(err, ErrUserDisabled)) {
query.AuthModule = "grafana"
return err
}
ldapEnabled, ldapErr := loginUsingLDAP(query)
ldapEnabled, ldapErr := loginUsingLDAP(ctx, query)
if ldapEnabled {
query.AuthModule = models.AuthModuleLDAP
if ldapErr == nil || !errors.Is(ldapErr, ldap.ErrInvalidCredentials) {
@ -58,7 +59,7 @@ func authenticateUser(query *models.LoginUserQuery) error {
}
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) {
if err := saveInvalidLoginAttempt(query); err != nil {
if err := saveInvalidLoginAttempt(ctx, query); err != nil {
loginLogger.Error("Failed to save invalid login attempt", "err", err)
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"errors"
"testing"
@ -20,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) {
Username: "user",
Password: "",
}
err := authenticateUser(&loginQuery)
err := authenticateUser(context.Background(), &loginQuery)
require.EqualError(t, err, ErrPasswordEmpty.Error())
assert.False(t, sc.grafanaLoginWasCalled)
@ -34,7 +35,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrTooManyLoginAttempts.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -50,7 +51,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -67,7 +68,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -83,7 +84,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(false, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, models.ErrUserNotFound.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -99,7 +100,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -115,7 +116,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -132,7 +133,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, customErr, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -148,7 +149,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := authenticateUser(sc.loginUserQuery)
err := authenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error())
assert.True(t, sc.loginAttemptValidationWasCalled)
@ -169,28 +170,28 @@ type authScenarioContext struct {
type authScenarioFunc func(sc *authScenarioContext)
func mockLoginUsingGrafanaDB(err error, sc *authScenarioContext) {
loginUsingGrafanaDB = func(query *models.LoginUserQuery) error {
loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
sc.grafanaLoginWasCalled = true
return err
}
}
func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
sc.ldapLoginWasCalled = true
return enabled, err
}
}
func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
validateLoginAttempts = func(*models.LoginUserQuery) error {
validateLoginAttempts = func(context.Context, *models.LoginUserQuery) error {
sc.loginAttemptValidationWasCalled = true
return err
}
}
func mockSaveInvalidLoginAttempt(sc *authScenarioContext) {
saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
sc.saveInvalidLoginAttemptWasCalled = true
return nil
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"time"
"github.com/grafana/grafana/pkg/bus"
@ -12,7 +13,7 @@ var (
loginAttemptsWindow = time.Minute * 5
)
var validateLoginAttempts = func(query *models.LoginUserQuery) error {
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery) error {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
@ -22,7 +23,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
Since: time.Now().Add(-loginAttemptsWindow),
}
if err := bus.Dispatch(&loginAttemptCountQuery); err != nil {
if err := bus.DispatchCtx(ctx, &loginAttemptCountQuery); err != nil {
return err
}
@ -33,7 +34,7 @@ var validateLoginAttempts = func(query *models.LoginUserQuery) error {
return nil
}
var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery) error {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
@ -43,5 +44,5 @@ var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
IpAddress: query.IpAddress,
}
return bus.Dispatch(&loginAttemptCommand)
return bus.DispatchCtx(ctx, &loginAttemptCommand)
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"testing"
"github.com/grafana/grafana/pkg/bus"
@ -61,7 +62,7 @@ func TestValidateLoginAttempts(t *testing.T) {
withLoginAttempts(t, tc.loginAttempts)
query := &models.LoginUserQuery{Username: "user", Cfg: tc.cfg}
err := validateLoginAttempts(query)
err := validateLoginAttempts(context.Background(), query)
require.Equal(t, tc.expected, err)
})
}
@ -77,7 +78,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
return nil
})
err := saveInvalidLoginAttempt(&models.LoginUserQuery{
err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
Username: "user",
Password: "pwd",
IpAddress: "192.168.1.1:56433",
@ -99,7 +100,7 @@ func TestSaveInvalidLoginAttempt(t *testing.T) {
return nil
})
err := saveInvalidLoginAttempt(&models.LoginUserQuery{
err := saveInvalidLoginAttempt(context.Background(), &models.LoginUserQuery{
Username: "user",
Password: "pwd",
IpAddress: "192.168.1.1:56433",

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"crypto/subtle"
"github.com/grafana/grafana/pkg/bus"
@ -20,10 +21,10 @@ var validatePassword = func(providedPassword string, userPassword string, userSa
return nil
}
var loginUsingGrafanaDB = func(query *models.LoginUserQuery) error {
var loginUsingGrafanaDB = func(ctx context.Context, query *models.LoginUserQuery) error {
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.Username}
if err := bus.Dispatch(&userQuery); err != nil {
if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"testing"
"github.com/grafana/grafana/pkg/bus"
@ -12,7 +13,7 @@ import (
func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with non-existing user", func(sc *grafanaLoginScenarioContext) {
sc.withNonExistingUser()
err := loginUsingGrafanaDB(sc.loginUserQuery)
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, models.ErrUserNotFound.Error())
assert.False(t, sc.validatePasswordCalled)
@ -21,7 +22,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with invalid credentials", func(sc *grafanaLoginScenarioContext) {
sc.withInvalidPassword()
err := loginUsingGrafanaDB(sc.loginUserQuery)
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error())
@ -31,7 +32,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with valid credentials", func(sc *grafanaLoginScenarioContext) {
sc.withValidCredentials()
err := loginUsingGrafanaDB(sc.loginUserQuery)
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
assert.True(t, sc.validatePasswordCalled)
@ -43,7 +44,7 @@ func TestLoginUsingGrafanaDB(t *testing.T) {
grafanaLoginScenario(t, "When login with disabled user", func(sc *grafanaLoginScenarioContext) {
sc.withDisabledUser()
err := loginUsingGrafanaDB(sc.loginUserQuery)
err := loginUsingGrafanaDB(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrUserDisabled.Error())
assert.False(t, sc.validatePasswordCalled)

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"errors"
"github.com/grafana/grafana/pkg/bus"
@ -26,7 +27,7 @@ var ldapLogger = log.New("login.ldap")
// loginUsingLDAP logs in user using LDAP. It returns whether LDAP is enabled and optional error and query arg will be
// populated with the logged in user if successful.
var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
var loginUsingLDAP = func(ctx context.Context, query *models.LoginUserQuery) (bool, error) {
enabled := isLDAPEnabled()
if !enabled {
@ -57,7 +58,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
ExternalUser: externalUser,
SignupAllowed: setting.LDAPAllowSignup,
}
err = bus.Dispatch(upsert)
err = bus.DispatchCtx(ctx, upsert)
if err != nil {
return true, err
}

View File

@ -1,6 +1,7 @@
package login
import (
"context"
"errors"
"testing"
@ -27,7 +28,7 @@ func TestLoginUsingLDAP(t *testing.T) {
return config, nil
}
enabled, err := loginUsingLDAP(sc.loginUserQuery)
enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, errTest.Error())
assert.True(t, enabled)
@ -38,7 +39,7 @@ func TestLoginUsingLDAP(t *testing.T) {
setting.LDAPEnabled = false
sc.withLoginResult(false)
enabled, err := loginUsingLDAP(sc.loginUserQuery)
enabled, err := loginUsingLDAP(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
assert.False(t, enabled)

View File

@ -50,11 +50,11 @@ func (srv *CleanUpService) Run(ctx context.Context) error {
srv.deleteExpiredSnapshots()
srv.deleteExpiredDashboardVersions()
srv.cleanUpOldAnnotations(ctxWithTimeout)
srv.expireOldUserInvites()
srv.deleteStaleShortURLs()
srv.expireOldUserInvites(ctx)
srv.deleteStaleShortURLs(ctx)
err := srv.ServerLockService.LockAndExecute(ctx, "delete old login attempts",
time.Minute*10, func(context.Context) {
srv.deleteOldLoginAttempts()
srv.deleteOldLoginAttempts(ctx)
})
if err != nil {
srv.log.Error("failed to lock and execute cleanup of old login attempts", "error", err)
@ -143,7 +143,7 @@ func (srv *CleanUpService) deleteExpiredDashboardVersions() {
}
}
func (srv *CleanUpService) deleteOldLoginAttempts() {
func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) {
if srv.Cfg.DisableBruteForceLoginProtection {
return
}
@ -151,31 +151,31 @@ func (srv *CleanUpService) deleteOldLoginAttempts() {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: time.Now().Add(time.Minute * -10),
}
if err := bus.Dispatch(&cmd); err != nil {
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
srv.log.Error("Problem deleting expired login attempts", "error", err.Error())
} else {
srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows)
}
}
func (srv *CleanUpService) expireOldUserInvites() {
func (srv *CleanUpService) expireOldUserInvites(ctx context.Context) {
maxInviteLifetime := srv.Cfg.UserInviteMaxLifetime
cmd := models.ExpireTempUsersCommand{
OlderThan: time.Now().Add(-maxInviteLifetime),
}
if err := bus.Dispatch(&cmd); err != nil {
if err := bus.DispatchCtx(ctx, &cmd); err != nil {
srv.log.Error("Problem expiring user invites", "error", err.Error())
} else {
srv.log.Debug("Expired user invites", "rows affected", cmd.NumExpired)
}
}
func (srv *CleanUpService) deleteStaleShortURLs() {
func (srv *CleanUpService) deleteStaleShortURLs(ctx context.Context) {
cmd := models.DeleteShortUrlCommand{
OlderThan: time.Now().Add(-time.Hour * 24 * 7),
}
if err := srv.ShortURLService.DeleteStaleShortURLs(context.Background(), &cmd); err != nil {
if err := srv.ShortURLService.DeleteStaleShortURLs(ctx, &cmd); err != nil {
srv.log.Error("Problem deleting stale short urls", "error", err.Error())
} else {
srv.log.Debug("Deleted short urls", "rows affected", cmd.NumDeleted)

View File

@ -272,7 +272,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext,
Password: password,
Cfg: h.Cfg,
}
if err := bus.Dispatch(&authQuery); err != nil {
if err := bus.DispatchCtx(reqContext.Req.Context(), &authQuery); err != nil {
reqContext.Logger.Debug(
"Failed to authorize the user",
"username", username,

View File

@ -13,15 +13,15 @@ import (
var getTime = time.Now
func (s *Implementation) GetExternalUserInfoByLogin(query *models.GetExternalUserInfoByLoginQuery) error {
func (s *Implementation) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
userQuery := models.GetUserByLoginQuery{LoginOrEmail: query.LoginOrEmail}
err := s.Bus.Dispatch(&userQuery)
err := s.Bus.DispatchCtx(ctx, &userQuery)
if err != nil {
return err
}
authInfoQuery := &models.GetAuthInfoQuery{UserId: userQuery.Result.Id}
if err := s.Bus.Dispatch(authInfoQuery); err != nil {
if err := s.Bus.DispatchCtx(context.TODO(), authInfoQuery); err != nil {
return err
}

View File

@ -32,7 +32,7 @@ func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectio
logger: log.New("login.authinfo"),
}
s.Bus.AddHandler(s.GetExternalUserInfoByLogin)
s.Bus.AddHandlerCtx(s.GetExternalUserInfoByLogin)
s.Bus.AddHandler(s.GetAuthInfo)
s.Bus.AddHandler(s.SetAuthInfo)
s.Bus.AddHandler(s.UpdateAuthInfo)

View File

@ -32,7 +32,7 @@ func ProvideService(bus bus.Bus, cfg *setting.Cfg) (*NotificationService, error)
}
ns.Bus.AddHandler(ns.sendResetPasswordEmail)
ns.Bus.AddHandler(ns.validateResetPasswordCode)
ns.Bus.AddHandlerCtx(ns.validateResetPasswordCode)
ns.Bus.AddHandler(ns.sendEmailCommandHandler)
ns.Bus.AddHandlerCtx(ns.sendEmailCommandHandlerSync)
@ -163,14 +163,14 @@ func (ns *NotificationService) sendResetPasswordEmail(cmd *models.SendResetPassw
})
}
func (ns *NotificationService) validateResetPasswordCode(query *models.ValidateResetPasswordCodeQuery) error {
func (ns *NotificationService) validateResetPasswordCode(ctx context.Context, query *models.ValidateResetPasswordCodeQuery) error {
login := getLoginForEmailCode(query.Code)
if login == "" {
return models.ErrInvalidEmailCode
}
userQuery := models.GetUserByLoginQuery{LoginOrEmail: login}
if err := bus.Dispatch(&userQuery); err != nil {
if err := bus.DispatchCtx(ctx, &userQuery); err != nil {
return err
}

View File

@ -38,7 +38,7 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI
}
authInfoQuery := &models.GetAuthInfoQuery{UserId: user.UserId}
if err := bus.Dispatch(authInfoQuery); err != nil {
if err := bus.DispatchCtx(ctx, authInfoQuery); err != nil {
if errors.Is(err, models.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("no OAuth token for user found", "userId", user.UserId, "username", user.Login)

View File

@ -1,6 +1,7 @@
package sqlstore
import (
"context"
"strconv"
"time"
@ -11,12 +12,12 @@ import (
var getTimeNow = time.Now
func init() {
bus.AddHandler("sql", CreateLoginAttempt)
bus.AddHandler("sql", DeleteOldLoginAttempts)
bus.AddHandler("sql", GetUserLoginAttemptCount)
bus.AddHandlerCtx("sql", CreateLoginAttempt)
bus.AddHandlerCtx("sql", DeleteOldLoginAttempts)
bus.AddHandlerCtx("sql", GetUserLoginAttemptCount)
}
func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error {
func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return inTransaction(func(sess *DBSession) error {
loginAttempt := models.LoginAttempt{
Username: cmd.Username,
@ -34,7 +35,7 @@ func CreateLoginAttempt(cmd *models.CreateLoginAttemptCommand) error {
})
}
func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error {
func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return inTransaction(func(sess *DBSession) error {
var maxId int64
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
@ -64,7 +65,7 @@ func DeleteOldLoginAttempts(cmd *models.DeleteOldLoginAttemptsCommand) error {
})
}
func GetUserLoginAttemptCount(query *models.GetUserLoginAttemptCountQuery) error {
func GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
loginAttempt := new(models.LoginAttempt)
total, err := x.
Where("username = ?", query.Username).

View File

@ -4,6 +4,7 @@
package sqlstore
import (
"context"
"testing"
"time"
@ -24,19 +25,19 @@ func TestLoginAttempts(t *testing.T) {
setup := func(t *testing.T) {
InitTestDB(t)
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
err := CreateLoginAttempt(&models.CreateLoginAttemptCommand{
err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
err = CreateLoginAttempt(&models.CreateLoginAttemptCommand{
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
@ -49,7 +50,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user,
Since: timePlusTwoMinutes.Add(time.Second * 1),
}
err := GetUserLoginAttemptCount(&query)
err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(0), query.Result)
})
@ -60,7 +61,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user,
Since: beginningOfTime,
}
err := GetUserLoginAttemptCount(&query)
err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(3), query.Result)
})
@ -71,7 +72,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user,
Since: timePlusOneMinute,
}
err := GetUserLoginAttemptCount(&query)
err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(2), query.Result)
})
@ -82,7 +83,7 @@ func TestLoginAttempts(t *testing.T) {
Username: user,
Since: timePlusTwoMinutes,
}
err := GetUserLoginAttemptCount(&query)
err := GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(1), query.Result)
})
@ -92,7 +93,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: beginningOfTime,
}
err := DeleteOldLoginAttempts(&cmd)
err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(0), cmd.DeletedRows)
@ -103,7 +104,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusOneMinute,
}
err := DeleteOldLoginAttempts(&cmd)
err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(1), cmd.DeletedRows)
@ -114,7 +115,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes,
}
err := DeleteOldLoginAttempts(&cmd)
err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(2), cmd.DeletedRows)
@ -125,7 +126,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
}
err := DeleteOldLoginAttempts(&cmd)
err := DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(3), cmd.DeletedRows)