mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: Move login attempt methods to separate service (#54479)
* Chore: Move login attempt methods to separate service * attempt to fix tests * fix syntax * better time mocking * initialise now func
This commit is contained in:
parent
d2bdb01092
commit
927ddf9376
@ -59,7 +59,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/live"
|
||||
"github.com/grafana/grafana/pkg/services/live/pushhttp"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
loginAttempt "github.com/grafana/grafana/pkg/services/login_attempt"
|
||||
loginAttempt "github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/ngalert"
|
||||
"github.com/grafana/grafana/pkg/services/notifications"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
|
@ -77,6 +77,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||
authinfodatabase "github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt/loginattemptimpl"
|
||||
"github.com/grafana/grafana/pkg/services/ngalert"
|
||||
ngmetrics "github.com/grafana/grafana/pkg/services/ngalert/metrics"
|
||||
"github.com/grafana/grafana/pkg/services/notifications"
|
||||
@ -219,6 +220,7 @@ var wireSet = wire.NewSet(
|
||||
authinfodatabase.ProvideAuthInfoStore,
|
||||
loginpkg.ProvideService,
|
||||
wire.Bind(new(loginpkg.Authenticator), new(*loginpkg.AuthenticatorService)),
|
||||
loginattemptimpl.ProvideService,
|
||||
datasourceproxy.ProvideService,
|
||||
search.ProvideService,
|
||||
searchV2.ProvideService,
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/ldap"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
@ -32,23 +33,23 @@ type Authenticator interface {
|
||||
}
|
||||
|
||||
type AuthenticatorService struct {
|
||||
store sqlstore.Store
|
||||
loginService login.Service
|
||||
userService user.Service
|
||||
loginService login.Service
|
||||
loginAttemptService loginattempt.Service
|
||||
userService user.Service
|
||||
}
|
||||
|
||||
func ProvideService(store sqlstore.Store, loginService login.Service, userService user.Service) *AuthenticatorService {
|
||||
func ProvideService(store sqlstore.Store, loginService login.Service, loginAttemptService loginattempt.Service, userService user.Service) *AuthenticatorService {
|
||||
a := &AuthenticatorService{
|
||||
store: store,
|
||||
loginService: loginService,
|
||||
userService: userService,
|
||||
loginService: loginService,
|
||||
loginAttemptService: loginAttemptService,
|
||||
userService: userService,
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
// AuthenticateUser authenticates the user via username & password
|
||||
func (a *AuthenticatorService) AuthenticateUser(ctx context.Context, query *models.LoginUserQuery) error {
|
||||
if err := validateLoginAttempts(ctx, query, a.store); err != nil {
|
||||
if err := validateLoginAttempts(ctx, query, a.loginAttemptService); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -76,7 +77,7 @@ func (a *AuthenticatorService) AuthenticateUser(ctx context.Context, query *mode
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ldap.ErrInvalidCredentials) {
|
||||
if err := saveInvalidLoginAttempt(ctx, query, a.store); err != nil {
|
||||
if err := saveInvalidLoginAttempt(ctx, query, a.loginAttemptService); err != nil {
|
||||
loginLogger.Error("Failed to save invalid login attempt", "err", err)
|
||||
}
|
||||
|
||||
|
@ -9,8 +9,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/ldap"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -26,7 +25,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
Username: "user",
|
||||
Password: "",
|
||||
}
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginAttemptService: nil, loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), &loginQuery)
|
||||
|
||||
require.EqualError(t, err, ErrPasswordEmpty.Error())
|
||||
@ -41,7 +40,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrTooManyLoginAttempts.Error())
|
||||
@ -58,7 +57,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.NoError(t, err)
|
||||
@ -76,7 +75,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, customErr.Error())
|
||||
@ -93,7 +92,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(false, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, user.ErrUserNotFound.Error())
|
||||
@ -110,7 +109,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrInvalidCredentials.Error())
|
||||
@ -127,7 +126,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, nil, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.NoError(t, err)
|
||||
@ -145,7 +144,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, customErr, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, customErr.Error())
|
||||
@ -162,7 +161,7 @@ func TestAuthenticateUser(t *testing.T) {
|
||||
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
|
||||
mockSaveInvalidLoginAttempt(sc)
|
||||
|
||||
a := AuthenticatorService{store: mockstore.NewSQLStoreMock(), loginService: &logintest.LoginServiceFake{}}
|
||||
a := AuthenticatorService{loginService: &logintest.LoginServiceFake{}}
|
||||
err := a.AuthenticateUser(context.Background(), sc.loginUserQuery)
|
||||
|
||||
require.EqualError(t, err, ErrInvalidCredentials.Error())
|
||||
@ -198,14 +197,14 @@ func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
|
||||
}
|
||||
|
||||
func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
|
||||
validateLoginAttempts = func(context.Context, *models.LoginUserQuery, sqlstore.Store) error {
|
||||
validateLoginAttempts = func(context.Context, *models.LoginUserQuery, loginattempt.Service) error {
|
||||
sc.loginAttemptValidationWasCalled = true
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func mockSaveInvalidLoginAttempt(sc *authScenarioContext) {
|
||||
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, _ sqlstore.Store) error {
|
||||
saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, _ loginattempt.Service) error {
|
||||
sc.saveInvalidLoginAttemptWasCalled = true
|
||||
return nil
|
||||
}
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -13,7 +13,7 @@ var (
|
||||
loginAttemptsWindow = time.Minute * 5
|
||||
)
|
||||
|
||||
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery, store sqlstore.Store) error {
|
||||
var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQuery, loginAttemptService loginattempt.Service) error {
|
||||
if query.Cfg.DisableBruteForceLoginProtection {
|
||||
return nil
|
||||
}
|
||||
@ -23,7 +23,7 @@ var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQue
|
||||
Since: time.Now().Add(-loginAttemptsWindow),
|
||||
}
|
||||
|
||||
if err := store.GetUserLoginAttemptCount(ctx, &loginAttemptCountQuery); err != nil {
|
||||
if err := loginAttemptService.GetUserLoginAttemptCount(ctx, &loginAttemptCountQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -34,7 +34,7 @@ var validateLoginAttempts = func(ctx context.Context, query *models.LoginUserQue
|
||||
return nil
|
||||
}
|
||||
|
||||
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, store sqlstore.Store) error {
|
||||
var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQuery, loginAttemptService loginattempt.Service) error {
|
||||
if query.Cfg.DisableBruteForceLoginProtection {
|
||||
return nil
|
||||
}
|
||||
@ -44,5 +44,5 @@ var saveInvalidLoginAttempt = func(ctx context.Context, query *models.LoginUserQ
|
||||
IpAddress: query.IpAddress,
|
||||
}
|
||||
|
||||
return store.CreateLoginAttempt(ctx, &loginAttemptCommand)
|
||||
return loginAttemptService.CreateLoginAttempt(ctx, &loginAttemptCommand)
|
||||
}
|
||||
|
@ -63,7 +63,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
|
||||
|
||||
sc.userService.ExpectedUser = &user.User{Password: encoded, ID: id, Salt: salt}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id}
|
||||
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, sc.userService)
|
||||
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, nil, sc.userService)
|
||||
|
||||
authHeader := util.GetBasicAuthHeader("myUser", password)
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats/statscollector"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api"
|
||||
_ "github.com/grafana/grafana/pkg/extensions"
|
||||
@ -43,10 +44,10 @@ type Options struct {
|
||||
func New(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistry accesscontrol.RoleRegistry,
|
||||
provisioningService provisioning.ProvisioningService, backgroundServiceProvider registry.BackgroundServiceRegistry,
|
||||
usageStatsProvidersRegistry registry.UsageStatsProvidersRegistry, statsCollectorService *statscollector.Service,
|
||||
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service,
|
||||
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service, loginAttemptService loginattempt.Service,
|
||||
) (*Server, error) {
|
||||
statsCollectorService.RegisterProviders(usageStatsProvidersRegistry.GetServices())
|
||||
s, err := newServer(opts, cfg, httpServer, roleRegistry, provisioningService, backgroundServiceProvider, secretMigrationService, userService)
|
||||
s, err := newServer(opts, cfg, httpServer, roleRegistry, provisioningService, backgroundServiceProvider, secretMigrationService, userService, loginAttemptService)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -60,7 +61,7 @@ func New(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistr
|
||||
|
||||
func newServer(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleRegistry accesscontrol.RoleRegistry,
|
||||
provisioningService provisioning.ProvisioningService, backgroundServiceProvider registry.BackgroundServiceRegistry,
|
||||
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service,
|
||||
secretMigrationService secretsMigrations.SecretMigrationService, userService user.Service, loginAttemptService loginattempt.Service,
|
||||
) (*Server, error) {
|
||||
rootCtx, shutdownFn := context.WithCancel(context.Background())
|
||||
childRoutines, childCtx := errgroup.WithContext(rootCtx)
|
||||
@ -82,6 +83,7 @@ func newServer(opts Options, cfg *setting.Cfg, httpServer *api.HTTPServer, roleR
|
||||
backgroundServices: backgroundServiceProvider.GetServices(),
|
||||
secretMigrationService: secretMigrationService,
|
||||
userService: userService,
|
||||
loginAttemptService: loginAttemptService,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@ -110,6 +112,7 @@ type Server struct {
|
||||
provisioningService provisioning.ProvisioningService
|
||||
secretMigrationService secretsMigrations.SecretMigrationService
|
||||
userService user.Service
|
||||
loginAttemptService loginattempt.Service
|
||||
}
|
||||
|
||||
// init initializes the server and its services.
|
||||
@ -127,7 +130,7 @@ func (s *Server) init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.userService)
|
||||
login.ProvideService(s.HTTPServer.SQLStore, s.HTTPServer.Login, s.loginAttemptService, s.userService)
|
||||
social.ProvideService(s.cfg)
|
||||
|
||||
if err := s.roleRegistry.RegisterFixedRoles(s.context); err != nil {
|
||||
|
@ -55,7 +55,7 @@ func testServer(t *testing.T, services ...registry.BackgroundService) *Server {
|
||||
secretMigrationService := &migrations.SecretMigrationServiceImpl{
|
||||
ServerLockService: serverLockService,
|
||||
}
|
||||
s, err := newServer(Options{}, setting.NewCfg(), nil, &ossaccesscontrol.Service{}, nil, backgroundsvcs.NewBackgroundServiceRegistry(services...), secretMigrationService, usertest.NewUserServiceFake())
|
||||
s, err := newServer(Options{}, setting.NewCfg(), nil, &ossaccesscontrol.Service{}, nil, backgroundsvcs.NewBackgroundServiceRegistry(services...), secretMigrationService, usertest.NewUserServiceFake(), nil)
|
||||
require.NoError(t, err)
|
||||
// Required to skip configuration initialization that causes
|
||||
// DI errors in this test.
|
||||
|
@ -80,7 +80,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||
authinfodatabase "github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/login_attempt/loginattemptimpl"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt/loginattemptimpl"
|
||||
"github.com/grafana/grafana/pkg/services/ngalert"
|
||||
ngimage "github.com/grafana/grafana/pkg/services/ngalert/image"
|
||||
ngmetrics "github.com/grafana/grafana/pkg/services/ngalert/metrics"
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/annotations"
|
||||
"github.com/grafana/grafana/pkg/services/dashboardsnapshots"
|
||||
dashver "github.com/grafana/grafana/pkg/services/dashboardversion"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/ngalert/image"
|
||||
"github.com/grafana/grafana/pkg/services/queryhistory"
|
||||
"github.com/grafana/grafana/pkg/services/shorturls"
|
||||
@ -23,7 +24,8 @@ import (
|
||||
|
||||
func ProvideService(cfg *setting.Cfg, serverLockService *serverlock.ServerLockService,
|
||||
shortURLService shorturls.Service, sqlstore *sqlstore.SQLStore, queryHistoryService queryhistory.Service,
|
||||
dashboardVersionService dashver.Service, dashSnapSvc dashboardsnapshots.Service, deleteExpiredImageService *image.DeleteExpiredService) *CleanUpService {
|
||||
dashboardVersionService dashver.Service, dashSnapSvc dashboardsnapshots.Service, deleteExpiredImageService *image.DeleteExpiredService,
|
||||
loginAttemptService loginattempt.Service) *CleanUpService {
|
||||
s := &CleanUpService{
|
||||
Cfg: cfg,
|
||||
ServerLockService: serverLockService,
|
||||
@ -34,6 +36,7 @@ func ProvideService(cfg *setting.Cfg, serverLockService *serverlock.ServerLockSe
|
||||
dashboardVersionService: dashboardVersionService,
|
||||
dashboardSnapshotService: dashSnapSvc,
|
||||
deleteExpiredImageService: deleteExpiredImageService,
|
||||
loginAttemptService: loginAttemptService,
|
||||
}
|
||||
return s
|
||||
}
|
||||
@ -48,6 +51,7 @@ type CleanUpService struct {
|
||||
dashboardVersionService dashver.Service
|
||||
dashboardSnapshotService dashboardsnapshots.Service
|
||||
deleteExpiredImageService *image.DeleteExpiredService
|
||||
loginAttemptService loginattempt.Service
|
||||
}
|
||||
|
||||
func (srv *CleanUpService) Run(ctx context.Context) error {
|
||||
@ -184,7 +188,7 @@ func (srv *CleanUpService) deleteOldLoginAttempts(ctx context.Context) {
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: time.Now().Add(time.Minute * -10),
|
||||
}
|
||||
if err := srv.store.DeleteOldLoginAttempts(ctx, &cmd); err != nil {
|
||||
if err := srv.loginAttemptService.DeleteOldLoginAttempts(ctx, &cmd); err != nil {
|
||||
srv.log.Error("Problem deleting expired login attempts", "error", err.Error())
|
||||
} else {
|
||||
srv.log.Debug("Deleted expired login attempts", "rows affected", cmd.DeletedRows)
|
||||
|
@ -1 +0,0 @@
|
||||
package loginattemptimpl
|
@ -1 +0,0 @@
|
||||
package loginattemptimpl
|
@ -2,27 +2,25 @@ package loginattemptimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
loginattempt "github.com/grafana/grafana/pkg/services/login_attempt"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/db"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
// TODO remove sqlstore
|
||||
sqlStore *sqlstore.SQLStore
|
||||
store store
|
||||
}
|
||||
|
||||
func ProvideService(
|
||||
ss *sqlstore.SQLStore,
|
||||
) loginattempt.Service {
|
||||
func ProvideService(db db.DB) loginattempt.Service {
|
||||
return &Service{
|
||||
sqlStore: ss,
|
||||
store: &xormStore{db: db, now: time.Now},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
|
||||
err := s.sqlStore.CreateLoginAttempt(ctx, cmd)
|
||||
err := s.store.CreateLoginAttempt(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -30,7 +28,7 @@ func (s *Service) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLogi
|
||||
}
|
||||
|
||||
func (s *Service) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
err := s.sqlStore.DeleteOldLoginAttempts(ctx, cmd)
|
||||
err := s.store.DeleteOldLoginAttempts(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -38,7 +36,7 @@ func (s *Service) DeleteOldLoginAttempts(ctx context.Context, cmd *models.Delete
|
||||
}
|
||||
|
||||
func (s *Service) GetUserLoginAttemptCount(ctx context.Context, cmd *models.GetUserLoginAttemptCountQuery) error {
|
||||
err := s.sqlStore.GetUserLoginAttemptCount(ctx, cmd)
|
||||
err := s.store.GetUserLoginAttemptCount(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
101
pkg/services/loginattempt/loginattemptimpl/store.go
Normal file
101
pkg/services/loginattempt/loginattemptimpl/store.go
Normal file
@ -0,0 +1,101 @@
|
||||
package loginattemptimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/db"
|
||||
)
|
||||
|
||||
type xormStore struct {
|
||||
db db.DB
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type store interface {
|
||||
CreateLoginAttempt(context.Context, *models.CreateLoginAttemptCommand) error
|
||||
DeleteOldLoginAttempts(context.Context, *models.DeleteOldLoginAttemptsCommand) error
|
||||
GetUserLoginAttemptCount(context.Context, *models.GetUserLoginAttemptCountQuery) error
|
||||
}
|
||||
|
||||
func (xs *xormStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
|
||||
return xs.db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||
loginAttempt := models.LoginAttempt{
|
||||
Username: cmd.Username,
|
||||
IpAddress: cmd.IpAddress,
|
||||
Created: xs.now().Unix(),
|
||||
}
|
||||
|
||||
if _, err := sess.Insert(&loginAttempt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Result = loginAttempt
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (xs *xormStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
return xs.db.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||
var maxId int64
|
||||
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
|
||||
result, err := sess.Query(sql, cmd.OlderThan.Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(result) == 0 || result[0] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: why don't we know the type of ID?
|
||||
maxId = toInt64(result[0]["id"])
|
||||
|
||||
if maxId == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sql = "DELETE FROM login_attempt WHERE id <= ?"
|
||||
|
||||
if result, err := sess.Exec(sql, maxId); err != nil {
|
||||
return err
|
||||
} else if cmd.DeletedRows, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (xs *xormStore) GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
|
||||
return xs.db.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
|
||||
loginAttempt := new(models.LoginAttempt)
|
||||
total, err := dbSession.
|
||||
Where("username = ?", query.Username).
|
||||
And("created >= ?", query.Since.Unix()).
|
||||
Count(loginAttempt)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query.Result = total
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func toInt64(i interface{}) int64 {
|
||||
switch i := i.(type) {
|
||||
case []byte:
|
||||
n, _ := strconv.ParseInt(string(i), 10, 64)
|
||||
return n
|
||||
case int:
|
||||
return int64(i)
|
||||
case int64:
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
144
pkg/services/loginattempt/loginattemptimpl/store_test.go
Normal file
144
pkg/services/loginattempt/loginattemptimpl/store_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
package loginattemptimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntegrationLoginAttemptsQuery(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
var loginAttemptService loginattempt.Service
|
||||
user := "user"
|
||||
|
||||
beginningOfTime := time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)
|
||||
timePlusOneMinute := beginningOfTime.Add(time.Minute * 1)
|
||||
timePlusTwoMinutes := beginningOfTime.Add(time.Minute * 2)
|
||||
|
||||
for _, test := range []struct {
|
||||
Name string
|
||||
Query models.GetUserLoginAttemptCountQuery
|
||||
Err error
|
||||
Result int64
|
||||
}{
|
||||
{
|
||||
"Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s",
|
||||
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 0,
|
||||
},
|
||||
{
|
||||
"Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s",
|
||||
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 0,
|
||||
},
|
||||
{
|
||||
"Should return the total count of login attempts since beginning of time",
|
||||
models.GetUserLoginAttemptCountQuery{Username: user, Since: beginningOfTime}, nil, 3,
|
||||
},
|
||||
{
|
||||
"Should return the total count of login attempts since beginning of time + 1min",
|
||||
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusOneMinute}, nil, 2,
|
||||
},
|
||||
{
|
||||
"Should return the total count of login attempts since beginning of time + 2min",
|
||||
models.GetUserLoginAttemptCountQuery{Username: user, Since: timePlusTwoMinutes}, nil, 1,
|
||||
},
|
||||
} {
|
||||
mockTime := beginningOfTime
|
||||
loginAttemptService = &Service{
|
||||
store: &xormStore{
|
||||
db: sqlstore.InitTestDB(t),
|
||||
now: func() time.Time { return mockTime },
|
||||
},
|
||||
}
|
||||
err := loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
mockTime = timePlusOneMinute
|
||||
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
mockTime = timePlusTwoMinutes
|
||||
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
err = loginAttemptService.GetUserLoginAttemptCount(context.Background(), &test.Query)
|
||||
require.Equal(t, test.Err, err, test.Name)
|
||||
require.Equal(t, test.Result, test.Query.Result, test.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationLoginAttemptsDelete(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
var loginAttemptService loginattempt.Service
|
||||
user := "user"
|
||||
|
||||
beginningOfTime := time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local)
|
||||
timePlusOneMinute := beginningOfTime.Add(time.Minute * 1)
|
||||
timePlusTwoMinutes := beginningOfTime.Add(time.Minute * 2)
|
||||
|
||||
for _, test := range []struct {
|
||||
Name string
|
||||
Cmd models.DeleteOldLoginAttemptsCommand
|
||||
Err error
|
||||
DeletedRows int64
|
||||
}{
|
||||
{
|
||||
"Should return deleted rows older than beginning of time",
|
||||
models.DeleteOldLoginAttemptsCommand{OlderThan: beginningOfTime}, nil, 0,
|
||||
},
|
||||
{
|
||||
"Should return deleted rows older than beginning of time + 1min",
|
||||
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusOneMinute}, nil, 1,
|
||||
},
|
||||
{
|
||||
"Should return deleted rows older than beginning of time + 2min",
|
||||
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusTwoMinutes}, nil, 2,
|
||||
},
|
||||
{
|
||||
"Should return deleted rows older than beginning of time + 2min and 1s",
|
||||
models.DeleteOldLoginAttemptsCommand{OlderThan: timePlusTwoMinutes.Add(time.Second * 1)}, nil, 3,
|
||||
},
|
||||
} {
|
||||
mockTime := beginningOfTime
|
||||
loginAttemptService = &Service{
|
||||
store: &xormStore{
|
||||
db: sqlstore.InitTestDB(t),
|
||||
now: func() time.Time { return mockTime },
|
||||
},
|
||||
}
|
||||
err := loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
mockTime = timePlusOneMinute
|
||||
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
mockTime = timePlusTwoMinutes
|
||||
err = loginAttemptService.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
err = loginAttemptService.DeleteOldLoginAttempts(context.Background(), &test.Cmd)
|
||||
require.Equal(t, test.Err, err, test.Name)
|
||||
require.Equal(t, test.DeletedRows, test.Cmd.DeletedRows, test.Name)
|
||||
}
|
||||
}
|
@ -1,89 +1 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
)
|
||||
|
||||
var getTimeNow = time.Now
|
||||
|
||||
func (ss *SQLStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
|
||||
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
|
||||
loginAttempt := models.LoginAttempt{
|
||||
Username: cmd.Username,
|
||||
IpAddress: cmd.IpAddress,
|
||||
Created: getTimeNow().Unix(),
|
||||
}
|
||||
|
||||
if _, err := sess.Insert(&loginAttempt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Result = loginAttempt
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (ss *SQLStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
|
||||
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
|
||||
var maxId int64
|
||||
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
|
||||
result, err := sess.Query(sql, cmd.OlderThan.Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(result) == 0 || result[0] == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxId = toInt64(result[0]["id"])
|
||||
|
||||
if maxId == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sql = "DELETE FROM login_attempt WHERE id <= ?"
|
||||
|
||||
if result, err := sess.Exec(sql, maxId); err != nil {
|
||||
return err
|
||||
} else if cmd.DeletedRows, err = result.RowsAffected(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (ss *SQLStore) GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error {
|
||||
return ss.WithDbSession(ctx, func(dbSession *DBSession) error {
|
||||
loginAttempt := new(models.LoginAttempt)
|
||||
total, err := dbSession.
|
||||
Where("username = ?", query.Username).
|
||||
And("created >= ?", query.Since.Unix()).
|
||||
Count(loginAttempt)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query.Result = total
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func toInt64(i interface{}) int64 {
|
||||
switch i := i.(type) {
|
||||
case []byte:
|
||||
n, _ := strconv.ParseInt(string(i), 10, 64)
|
||||
return n
|
||||
case int:
|
||||
return int64(i)
|
||||
case int64:
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
@ -1,135 +1 @@
|
||||
package sqlstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mockTime(mock time.Time) time.Time {
|
||||
getTimeNow = func() time.Time { return mock }
|
||||
return mock
|
||||
}
|
||||
|
||||
func TestIntegrationLoginAttempts(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test")
|
||||
}
|
||||
var beginningOfTime, timePlusOneMinute, timePlusTwoMinutes time.Time
|
||||
var sqlStore *SQLStore
|
||||
user := "user"
|
||||
|
||||
setup := func(t *testing.T) {
|
||||
sqlStore = InitTestDB(t)
|
||||
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
|
||||
err := sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
|
||||
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
|
||||
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
|
||||
Username: user,
|
||||
IpAddress: "192.168.0.1",
|
||||
})
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
t.Run("Should return a total count of zero login attempts when comparing since beginning of time + 2min and 1s", func(t *testing.T) {
|
||||
setup(t)
|
||||
query := models.GetUserLoginAttemptCountQuery{
|
||||
Username: user,
|
||||
Since: timePlusTwoMinutes.Add(time.Second * 1),
|
||||
}
|
||||
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), query.Result)
|
||||
})
|
||||
|
||||
t.Run("Should return the total count of login attempts since beginning of time", func(t *testing.T) {
|
||||
setup(t)
|
||||
query := models.GetUserLoginAttemptCountQuery{
|
||||
Username: user,
|
||||
Since: beginningOfTime,
|
||||
}
|
||||
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(3), query.Result)
|
||||
})
|
||||
|
||||
t.Run("Should return the total count of login attempts since beginning of time + 1min", func(t *testing.T) {
|
||||
setup(t)
|
||||
query := models.GetUserLoginAttemptCountQuery{
|
||||
Username: user,
|
||||
Since: timePlusOneMinute,
|
||||
}
|
||||
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), query.Result)
|
||||
})
|
||||
|
||||
t.Run("Should return the total count of login attempts since beginning of time + 2min", func(t *testing.T) {
|
||||
setup(t)
|
||||
query := models.GetUserLoginAttemptCountQuery{
|
||||
Username: user,
|
||||
Since: timePlusTwoMinutes,
|
||||
}
|
||||
err := sqlStore.GetUserLoginAttemptCount(context.Background(), &query)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1), query.Result)
|
||||
})
|
||||
|
||||
t.Run("Should return deleted rows older than beginning of time", func(t *testing.T) {
|
||||
setup(t)
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: beginningOfTime,
|
||||
}
|
||||
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(0), cmd.DeletedRows)
|
||||
})
|
||||
|
||||
t.Run("Should return deleted rows older than beginning of time + 1min", func(t *testing.T) {
|
||||
setup(t)
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusOneMinute,
|
||||
}
|
||||
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1), cmd.DeletedRows)
|
||||
})
|
||||
|
||||
t.Run("Should return deleted rows older than beginning of time + 2min", func(t *testing.T) {
|
||||
setup(t)
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusTwoMinutes,
|
||||
}
|
||||
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(2), cmd.DeletedRows)
|
||||
})
|
||||
|
||||
t.Run("Should return deleted rows older than beginning of time + 2min and 1s", func(t *testing.T) {
|
||||
setup(t)
|
||||
cmd := models.DeleteOldLoginAttemptsCommand{
|
||||
OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
|
||||
}
|
||||
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(3), cmd.DeletedRows)
|
||||
})
|
||||
}
|
||||
|
@ -26,9 +26,6 @@ type Store interface {
|
||||
DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error
|
||||
GetOrgById(context.Context, *models.GetOrgByIdQuery) error
|
||||
GetOrgByNameHandler(ctx context.Context, query *models.GetOrgByNameQuery) error
|
||||
CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error
|
||||
GetUserLoginAttemptCount(ctx context.Context, query *models.GetUserLoginAttemptCountQuery) error
|
||||
DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error
|
||||
CreateUser(ctx context.Context, cmd user.CreateUserCommand) (*user.User, error)
|
||||
SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error
|
||||
GetUserProfile(ctx context.Context, query *models.GetUserProfileQuery) error
|
||||
|
Loading…
Reference in New Issue
Block a user