Chore: Move login attempt methods to separate service (#54479)

* Chore: Move login attempt methods to separate service

* attempt to fix tests

* fix syntax

* better time mocking

* initialise now func
This commit is contained in:
Serge Zaitsev 2022-09-01 18:08:42 +02:00 committed by GitHub
parent d2bdb01092
commit 927ddf9376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 300 additions and 275 deletions

View File

@ -59,7 +59,7 @@ import (
"github.com/grafana/grafana/pkg/services/live"
"github.com/grafana/grafana/pkg/services/live/pushhttp"
"github.com/grafana/grafana/pkg/services/login"
loginAttempt "github.com/grafana/grafana/pkg/services/login_attempt"
loginAttempt "github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/ngalert"
"github.com/grafana/grafana/pkg/services/notifications"
"github.com/grafana/grafana/pkg/services/org"

View File

@ -77,6 +77,7 @@ import (
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
authinfodatabase "github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/loginattempt/loginattemptimpl"
"github.com/grafana/grafana/pkg/services/ngalert"
ngmetrics "github.com/grafana/grafana/pkg/services/ngalert/metrics"
"github.com/grafana/grafana/pkg/services/notifications"
@ -219,6 +220,7 @@ var wireSet = wire.NewSet(
authinfodatabase.ProvideAuthInfoStore,
loginpkg.ProvideService,
wire.Bind(new(loginpkg.Authenticator), new(*loginpkg.AuthenticatorService)),
loginattemptimpl.ProvideService,
datasourceproxy.ProvideService,
search.ProvideService,
searchV2.ProvideService,

View File

@ -8,6 +8,7 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
)
@ -32,23 +33,23 @@ type Authenticator interface {
}
type AuthenticatorService struct {
store sqlstore.Store
loginService login.Service
userService user.Service
loginService login.Service
loginAttemptService loginattempt.Service
userService user.Service
}
func ProvideService(store sqlstore.Store, loginService login.Service, userService user.Service) *AuthenticatorService {
func ProvideService(store sqlstore.Store, loginService login.Service, loginAttemptService loginattempt.Service, userService user.Service) *AuthenticatorService {
a := &AuthenticatorService{
store: store,
loginService: loginService,
userService: userService,
loginService: loginService,
loginAttemptService: loginAttemptService,
userService: userService,
}
return a
}
// AuthenticateUser authenticates the user via username & password
func (a *AuthenticatorService) AuthenticateUser(ctx context.Context, query *models.LoginUserQuery) error {
if err := validateLoginAttempts(ctx, query, a.store); err != nil {
if err := validateLoginAttempts(ctx, query, a.loginAttemptService); err != nil {
return err
}
@ -76,7 +77,7 @@ func (a *AuthenticatorService) AuthenticateUser(ctx context.Context, query *mode
}
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) {
if err := saveInvalidLoginAttempt(ctx, query, a.store); err != nil {
if err := saveInvalidLoginAttempt(ctx, query, a.loginAttemptService); err != nil {
loginLogger.Error("Failed to save invalid login attempt", "err", err)
}

View File

@ -9,8 +9,7 @@ import (
"github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/logintest"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -26,7 +25,7 @@ func TestAuthenticateUser(t *testing.T) {
Username: "user",
Password: "",
}
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginAttemptService: nil, loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), &loginQuery)
require.EqualError(t, err, ErrPasswordEmpty.Error())
@ -41,7 +40,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrTooManyLoginAttempts.Error())
@ -58,7 +57,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
@ -76,7 +75,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error())
@ -93,7 +92,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(false, nil, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, user.ErrUserNotFound.Error())
@ -110,7 +109,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error())
@ -127,7 +126,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.NoError(t, err)
@ -145,7 +144,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, customErr, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, customErr.Error())
@ -162,7 +161,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
require.EqualError(t, err, ErrInvalidCredentials.Error())
@ -198,14 +197,14 @@ func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
}
func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
validateLoginAttempts = func(context.Context, *models.LoginUserQuery, sqlstore.Store) error {
validateLoginAttempts = func(context.Context, *models.LoginUserQuery, loginattempt.Service) error {
sc.loginAttemptValidationWasCalled = true
return err
}
}
func mockSaveInvalidLoginAttempt(sc *authScenarioContext) {
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, _ sqlstore.Store) error {
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, _ loginattempt.Service) error {
sc.saveInvalidLoginAttemptWasCalled = true
return nil
}

View File

@ -5,7 +5,7 @@ import (
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/loginattempt"
)
var (
@ -13,7 +13,7 @@ var (
loginAttemptsWindow = time.Minute * 5
)
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery, store sqlstore.Store) error {
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery, loginAttemptService loginattempt.Service) error {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
@ -23,7 +23,7 @@ var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQue
Since: time.Now().Add(-loginAttemptsWindow),
}
if err := store.GetUserLoginAttemptCount(ctx, &loginAttemptCountQuery); err != nil {
if err := loginAttemptService.GetUserLoginAttemptCount(ctx, &loginAttemptCountQuery); err != nil {
return err
}
@ -34,7 +34,7 @@ var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQue
return nil
}
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, store sqlstore.Store) error {
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, loginAttemptService loginattempt.Service) error {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
@ -44,5 +44,5 @@ var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQ
IpAddress: query.IpAddress,
}
return store.CreateLoginAttempt(ctx, &loginAttemptCommand)
return loginAttemptService.CreateLoginAttempt(ctx, &loginAttemptCommand)
}

View File

@ -63,7 +63,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
sc.userService.ExpectedUser = &user.User{Password: encoded, ID: id, Salt: salt}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id}
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, sc.userService)
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, nil, sc.userService)
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()

View File

@ -13,6 +13,7 @@ import (
"github.com/grafana/grafana/pkg/infra/usagestats/statscollector"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/api"
_ "github.com/grafana/grafana/pkg/extensions"
@ -43,10 +44,10 @@ type Options struct {
func New(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistry accesscontrol.RoleRegistry,
provisioningService provisioning.ProvisioningService, backgroundServiceProvider registry.BackgroundServiceRegistry,
usageStatsProvidersRegistry registry.UsageStatsProvidersRegistry, statsCollectorService *statscollector.Service,
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service,
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service, loginAttemptService loginattempt.Service,
) (*Server, error) {
statsCollectorService.RegisterProviders(usageStatsProvidersRegistry.GetServices())
s, err := newServer(opts, cfg, httpServer, roleRegistry, provisioningService, backgroundServiceProvider, secretMigrationService, userService)
s, err := newServer(opts, cfg, httpServer, roleRegistry, provisioningService, backgroundServiceProvider, secretMigrationService, userService, loginAttemptService)
if err != nil {
return nil, err
}
@ -60,7 +61,7 @@ func New(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistr
func newServer(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistry accesscontrol.RoleRegistry,
provisioningService provisioning.ProvisioningService, backgroundServiceProvider registry.BackgroundServiceRegistry,
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service,
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service, loginAttemptService loginattempt.Service,
) (*Server, error) {
rootCtx, shutdownFn := context.WithCancel(context.Background())
childRoutines, childCtx := errgroup.WithContext(rootCtx)
@ -82,6 +83,7 @@ func newServer(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleR
backgroundServices: backgroundServiceProvider.GetServices(),
secretMigrationService: secretMigrationService,
userService: userService,
loginAttemptService: loginAttemptService,
}
return s, nil
@ -110,6 +112,7 @@ type Server struct {
provisioningService provisioning.ProvisioningService
secretMigrationService secretsMigrations.SecretMigrationService
userService user.Service
loginAttemptService loginattempt.Service
}
// init initializes the server and its services.
@ -127,7 +130,7 @@ func (s *Server) init() error {
return err
}
login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.userService)
login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService)
social.ProvideService(s.cfg)
if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil {

View File

@ -55,7 +55,7 @@ func testServer(t *testing.T, services ...registry.BackgroundService) *Server {
secretMigrationService := &migrations.SecretMigrationServiceImpl{
ServerLockService: serverLockService,
}
s, err := newServer(Options{}, setting.NewCfg(), nil, &ossaccesscontrol.Service{}, nil, backgroundsvcs.NewBackgroundServiceRegistry(services...), secretMigrationService, usertest.NewUserServiceFake())
s, err := newServer(Options{}, setting.NewCfg(), nil, &ossaccesscontrol.Service{}, nil, backgroundsvcs.NewBackgroundServiceRegistry(services...), secretMigrationService, usertest.NewUserServiceFake(), nil)
require.NoError(t, err)
// Required to skip configuration initialization that causes
// DI errors in this test.

View File

@ -80,7 +80,7 @@ import (
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
authinfodatabase "github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/login_attempt/loginattemptimpl"
"github.com/grafana/grafana/pkg/services/loginattempt/loginattemptimpl"
"github.com/grafana/grafana/pkg/services/ngalert"
ngimage "github.com/grafana/grafana/pkg/services/ngalert/image"
ngmetrics "github.com/grafana/grafana/pkg/services/ngalert/metrics"

View File

@ -14,6 +14,7 @@ import (
"github.com/grafana/grafana/pkg/services/annotations"
"github.com/grafana/grafana/pkg/services/dashboardsnapshots"
dashver "github.com/grafana/grafana/pkg/services/dashboardversion"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/ngalert/image"
"github.com/grafana/grafana/pkg/services/queryhistory"
"github.com/grafana/grafana/pkg/services/shorturls"
@ -23,7 +24,8 @@ import (
func ProvideService(cfg *setting.Cfg, serverLockService *serverlock.ServerLockService,
shortURLService shorturls.Service, sqlstore *sqlstore.SQLStore, queryHistoryService queryhistory.Service,
dashboardVersionService dashver.Service, dashSnapSvc dashboardsnapshots.Service, deleteExpiredImageService *image.DeleteExpiredService) *CleanUpService {
dashboardVersionService dashver.Service, dashSnapSvc dashboardsnapshots.Service, deleteExpiredImageService *image.DeleteExpiredService,
loginAttemptService loginattempt.Service) *CleanUpService {
s := &CleanUpService{
Cfg: cfg,
ServerLockService: serverLockService,
@ -34,6 +36,7 @@ func ProvideService(cfg *setting.Cfg, serverLockService *serverlock.ServerLockSe
dashboardVersionService: dashboardVersionService,
dashboardSnapshotService: dashSnapSvc,
deleteExpiredImageService: deleteExpiredImageService,
loginAttemptService: loginAttemptService,
}
return s
}
@ -48,6 +51,7 @@ type CleanUpService struct {
dashboardVersionService dashver.Service
dashboardSnapshotService dashboardsnapshots.Service
deleteExpiredImageService *image.DeleteExpiredService
loginAttemptService loginattempt.Service
}
func (srv *CleanUpService) Run(ctx context.Context) error {
@ -184,7 +188,7 @@ func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: time.Now().Add(time.Minute * -10),
}
if err := srv.store.DeleteOldLoginAttempts(ctx, &cmd); err != nil {
if err := srv.loginAttemptService.DeleteOldLoginAttempts(ctx, &cmd); err != nil {
srv.log.Error("Problem deleting expired login attempts", "error", err.Error())
} else {
srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows)

View File

@ -1 +0,0 @@
package loginattemptimpl

View File

@ -1 +0,0 @@
package loginattemptimpl

View File

@ -2,27 +2,25 @@ package loginattemptimpl
import (
"context"
"time"
"github.com/grafana/grafana/pkg/models"
loginattempt "github.com/grafana/grafana/pkg/services/login_attempt"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/sqlstore/db"
)
type Service struct {
// TODO remove sqlstore
sqlStore *sqlstore.SQLStore
store store
}
func ProvideService(
ss *sqlstore.SQLStore,
) loginattempt.Service {
func ProvideService(db db.DB) loginattempt.Service {
return &Service{
sqlStore: ss,
store: &xormStore{db: db, now: time.Now},
}
}
func (s *Service) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
err := s.sqlStore.CreateLoginAttempt(ctx, cmd)
err := s.store.CreateLoginAttempt(ctx, cmd)
if err != nil {
return err
}
@ -30,7 +28,7 @@ func (s *Service) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLogi
}
func (s *Service) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
err := s.sqlStore.DeleteOldLoginAttempts(ctx, cmd)
err := s.store.DeleteOldLoginAttempts(ctx, cmd)
if err != nil {
return err
}
@ -38,7 +36,7 @@ func (s *Service) DeleteOldLoginAttempts(ctx context.Context, cmd *models.Delete
}
func (s *Service) GetUserLoginAttemptCount(ctx context.Context, cmd *models.GetUserLoginAttemptCountQuery) error {
err := s.sqlStore.GetUserLoginAttemptCount(ctx, cmd)
err := s.store.GetUserLoginAttemptCount(ctx, cmd)
if err != nil {
return err
}

View File

@ -0,0 +1,101 @@
package loginattemptimpl
import (
"context"
"strconv"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/db"
)
type xormStore struct {
db db.DB
now func() time.Time
}
type store interface {
CreateLoginAttempt(context.Context, *models.CreateLoginAttemptCommand) error
DeleteOldLoginAttempts(context.Context, *models.DeleteOldLoginAttemptsCommand) error
GetUserLoginAttemptCount(context.Context, *models.GetUserLoginAttemptCountQuery) error
}
func (xs *xormStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return xs.db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
loginAttempt := models.LoginAttempt{
Username: cmd.Username,
IpAddress: cmd.IpAddress,
Created: xs.now().Unix(),
}
if _, err := sess.Insert(&loginAttempt); err != nil {
return err
}
cmd.Result = loginAttempt
return nil
})
}
func (xs *xormStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return xs.db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
var maxId int64
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
result, err := sess.Query(sql, cmd.OlderThan.Unix())
if err != nil {
return err
}
if len(result) == 0 || result[0] == nil {
return nil
}
// TODO: why don't we know the type of ID?
maxId = toInt64(result[0]["id"])
if maxId == 0 {
return nil
}
sql = "DELETE FROM login_attempt WHERE id <= ?"
if result, err := sess.Exec(sql, maxId); err != nil {
return err
} else if cmd.DeletedRows, err = result.RowsAffected(); err != nil {
return err
}
return nil
})
}
func (xs *xormStore) GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
return xs.db.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
loginAttempt := new(models.LoginAttempt)
total, err := dbSession.
Where("username = ?", query.Username).
And("created >= ?", query.Since.Unix()).
Count(loginAttempt)
if err != nil {
return err
}
query.Result = total
return nil
})
}
func toInt64(i interface{}) int64 {
switch i := i.(type) {
case []byte:
n, _ := strconv.ParseInt(string(i), 10, 64)
return n
case int:
return int64(i)
case int64:
return i
}
return 0
}

View File

@ -0,0 +1,144 @@
package loginattemptimpl
import (
"context"
"testing"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/loginattempt"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/stretchr/testify/require"
)
func TestIntegrationLoginAttemptsQuery(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
var loginAttemptService loginattempt.Service
user := "user"
beginningOfTime := time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)
timePlusOneMinute := beginningOfTime.Add(time.Minute * 1)
timePlusTwoMinutes := beginningOfTime.Add(time.Minute * 2)
for _, test := range []struct {
Name string
Query models.GetUserLoginAttemptCountQuery
Err error
Result int64
}{
{
"Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s",
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 0,
},
{
"Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s",
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 0,
},
{
"Should return the total count of login attempts since beginning of time",
models.GetUserLoginAttemptCountQuery{Username: user, Since: beginningOfTime}, nil, 3,
},
{
"Should return the total count of login attempts since beginning of time + 1min",
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusOneMinute}, nil, 2,
},
{
"Should return the total count of login attempts since beginning of time + 2min",
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes}, nil, 1,
},
} {
mockTime := beginningOfTime
loginAttemptService = &Service{
store: &xormStore{
db: sqlstore.InitTestDB(t),
now: func() time.Time { return mockTime },
},
}
err := loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
mockTime = timePlusOneMinute
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
mockTime = timePlusTwoMinutes
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
err = loginAttemptService.GetUserLoginAttemptCount(context.Background(), &test.Query)
require.Equal(t, test.Err, err, test.Name)
require.Equal(t, test.Result, test.Query.Result, test.Name)
}
}
func TestIntegrationLoginAttemptsDelete(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
var loginAttemptService loginattempt.Service
user := "user"
beginningOfTime := time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)
timePlusOneMinute := beginningOfTime.Add(time.Minute * 1)
timePlusTwoMinutes := beginningOfTime.Add(time.Minute * 2)
for _, test := range []struct {
Name string
Cmd models.DeleteOldLoginAttemptsCommand
Err error
DeletedRows int64
}{
{
"Should return deleted rows older than beginning of time",
models.DeleteOldLoginAttemptsCommand{OlderThan: beginningOfTime}, nil, 0,
},
{
"Should return deleted rows older than beginning of time + 1min",
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusOneMinute}, nil, 1,
},
{
"Should return deleted rows older than beginning of time + 2min",
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusTwoMinutes}, nil, 2,
},
{
"Should return deleted rows older than beginning of time + 2min and 1s",
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 3,
},
} {
mockTime := beginningOfTime
loginAttemptService = &Service{
store: &xormStore{
db: sqlstore.InitTestDB(t),
now: func() time.Time { return mockTime },
},
}
err := loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
mockTime = timePlusOneMinute
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
mockTime = timePlusTwoMinutes
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
err = loginAttemptService.DeleteOldLoginAttempts(context.Background(), &test.Cmd)
require.Equal(t, test.Err, err, test.Name)
require.Equal(t, test.DeletedRows, test.Cmd.DeletedRows, test.Name)
}
}

View File

@ -1,89 +1 @@
package sqlstore
import (
"context"
"strconv"
"time"
"github.com/grafana/grafana/pkg/models"
)
var getTimeNow = time.Now
func (ss *SQLStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
loginAttempt := models.LoginAttempt{
Username: cmd.Username,
IpAddress: cmd.IpAddress,
Created: getTimeNow().Unix(),
}
if _, err := sess.Insert(&loginAttempt); err != nil {
return err
}
cmd.Result = loginAttempt
return nil
})
}
func (ss *SQLStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var maxId int64
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
result, err := sess.Query(sql, cmd.OlderThan.Unix())
if err != nil {
return err
}
if len(result) == 0 || result[0] == nil {
return nil
}
maxId = toInt64(result[0]["id"])
if maxId == 0 {
return nil
}
sql = "DELETE FROM login_attempt WHERE id <= ?"
if result, err := sess.Exec(sql, maxId); err != nil {
return err
} else if cmd.DeletedRows, err = result.RowsAffected(); err != nil {
return err
}
return nil
})
}
func (ss *SQLStore) GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
return ss.WithDbSession(ctx, func(dbSession *DBSession) error {
loginAttempt := new(models.LoginAttempt)
total, err := dbSession.
Where("username = ?", query.Username).
And("created >= ?", query.Since.Unix()).
Count(loginAttempt)
if err != nil {
return err
}
query.Result = total
return nil
})
}
func toInt64(i interface{}) int64 {
switch i := i.(type) {
case []byte:
n, _ := strconv.ParseInt(string(i), 10, 64)
return n
case int:
return int64(i)
case int64:
return i
}
return 0
}

View File

@ -1,135 +1 @@
package sqlstore
import (
"context"
"testing"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/stretchr/testify/require"
)
func mockTime(mock time.Time) time.Time {
getTimeNow = func() time.Time { return mock }
return mock
}
func TestIntegrationLoginAttempts(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
var beginningOfTime, timePlusOneMinute, timePlusTwoMinutes time.Time
var sqlStore *SQLStore
user := "user"
setup := func(t *testing.T) {
sqlStore = InitTestDB(t)
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
err := sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
}
t.Run("Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s", func(t *testing.T) {
setup(t)
query := models.GetUserLoginAttemptCountQuery{
Username: user,
Since: timePlusTwoMinutes.Add(time.Second * 1),
}
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(0), query.Result)
})
t.Run("Should return the total count of login attempts since beginning of time", func(t *testing.T) {
setup(t)
query := models.GetUserLoginAttemptCountQuery{
Username: user,
Since: beginningOfTime,
}
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(3), query.Result)
})
t.Run("Should return the total count of login attempts since beginning of time + 1min", func(t *testing.T) {
setup(t)
query := models.GetUserLoginAttemptCountQuery{
Username: user,
Since: timePlusOneMinute,
}
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(2), query.Result)
})
t.Run("Should return the total count of login attempts since beginning of time + 2min", func(t *testing.T) {
setup(t)
query := models.GetUserLoginAttemptCountQuery{
Username: user,
Since: timePlusTwoMinutes,
}
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
require.Nil(t, err)
require.Equal(t, int64(1), query.Result)
})
t.Run("Should return deleted rows older than beginning of time", func(t *testing.T) {
setup(t)
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: beginningOfTime,
}
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(0), cmd.DeletedRows)
})
t.Run("Should return deleted rows older than beginning of time + 1min", func(t *testing.T) {
setup(t)
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusOneMinute,
}
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(1), cmd.DeletedRows)
})
t.Run("Should return deleted rows older than beginning of time + 2min", func(t *testing.T) {
setup(t)
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes,
}
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(2), cmd.DeletedRows)
})
t.Run("Should return deleted rows older than beginning of time + 2min and 1s", func(t *testing.T) {
setup(t)
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
}
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(3), cmd.DeletedRows)
})
}

View File

@ -26,9 +26,6 @@ type Store interface {
DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error
GetOrgById(context.Context, *models.GetOrgByIdQuery) error
GetOrgByNameHandler(ctx context.Context, query *models.GetOrgByNameQuery) error
CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error
GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error
DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error
CreateUser(ctx context.Context, cmd user.CreateUserCommand) (*user.User, error)
SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error
GetUserProfile(ctx context.Context, query *models.GetUserProfileQuery) error