mirror of
https://github.com/grafana/grafana.git
synced 2024-11-26 02:40:26 -06:00
Login: remove unused function (#78442)
* Move test to the db so we test the queries and not just testing the mock * Remove unused function and dependencies * Remove unused functions from the database * Add some integration tests
This commit is contained in:
parent
e2f2d8b3d6
commit
d42201dbf4
@ -18,7 +18,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
@ -66,9 +65,7 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) {
|
||||
secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore))
|
||||
authInfoStore := authinfostore.ProvideAuthInfoStore(sqlStore, secretsService, userMock)
|
||||
srv := authinfoservice.ProvideAuthInfoService(
|
||||
&authinfoservice.OSSUserProtectionImpl{},
|
||||
authInfoStore,
|
||||
&usagestats.UsageStatsMock{},
|
||||
)
|
||||
hs.authInfoService = srv
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, quotatest.New(false, nil))
|
||||
|
@ -33,7 +33,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
|
||||
userProtection := &authinfoservice.OSSUserProtectionImpl{}
|
||||
|
||||
authFakeNil := &logintest.AuthInfoServiceFake{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: user.ErrUserNotFound,
|
||||
SetAuthInfoFn: func(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
return nil
|
||||
@ -43,7 +42,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
|
||||
},
|
||||
}
|
||||
authFakeUserID := &logintest.AuthInfoServiceFake{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: nil,
|
||||
ExpectedUserAuth: &login.UserAuth{
|
||||
AuthModule: "oauth",
|
||||
@ -68,7 +66,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
|
||||
}}
|
||||
|
||||
userServiceNil := &usertest.FakeUserService{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: user.ErrUserNotFound,
|
||||
CreateFn: func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error) {
|
||||
return &user.User{
|
||||
|
@ -3,12 +3,10 @@ package login
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
type AuthInfoService interface {
|
||||
LookupAndUpdate(ctx context.Context, query *GetUserByAuthInfoQuery) (*user.User, error)
|
||||
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
|
||||
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
|
||||
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
|
||||
@ -21,15 +19,9 @@ type Store interface {
|
||||
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
|
||||
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
|
||||
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error
|
||||
UpdateAuthInfoDate(ctx context.Context, authInfo *UserAuth) error
|
||||
DeleteAuthInfo(ctx context.Context, cmd *DeleteAuthInfoCommand) error
|
||||
DeleteUserAuthInfo(ctx context.Context, userID int64) error
|
||||
GetUserById(ctx context.Context, id int64) (*user.User, error)
|
||||
GetUserByLogin(ctx context.Context, login string) (*user.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*user.User, error)
|
||||
CollectLoginStats(ctx context.Context) (map[string]any, error)
|
||||
RunMetricsCollection(ctx context.Context) error
|
||||
GetLoginStats(ctx context.Context) (LoginStats, error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -153,23 +153,6 @@ func (s *AuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoC
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAuthInfoDate updates the auth info for the user with the latest date.
|
||||
// Avoids overlapping entries hiding the last used one (ex: LDAP->SAML->LDAP).
|
||||
func (s *AuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
|
||||
authInfo.Created = GetTime()
|
||||
|
||||
cond := &login.UserAuth{
|
||||
Id: authInfo.Id,
|
||||
UserId: authInfo.UserId,
|
||||
AuthModule: authInfo.AuthModule,
|
||||
}
|
||||
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
|
||||
_, err := sess.Cols("created").Update(authInfo, cond)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
|
||||
authUser := &login.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
@ -239,13 +222,6 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAut
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
||||
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
|
||||
_, err := sess.Delete(cmd.UserAuth)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
|
||||
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
|
||||
var rawSQL = "DELETE FROM user_auth WHERE user_id = ?"
|
||||
@ -254,36 +230,6 @@ func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) er
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
|
||||
query := user.GetUserByIDQuery{ID: id}
|
||||
user, err := s.userService.GetByID(ctx, &query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
|
||||
query := user.GetUserByLoginQuery{LoginOrEmail: login}
|
||||
usr, err := s.userService.GetByLogin(ctx, &query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
|
||||
query := user.GetUserByEmailQuery{Email: email}
|
||||
usr, err := s.userService.GetByEmail(ctx, &query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
// decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it
|
||||
func (s *AuthInfoStore) decodeAndDecrypt(str string) (string, error) {
|
||||
// Bail out if empty string since it'll cause a segfault in Decrypt
|
||||
|
@ -5,12 +5,14 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
secretstest "github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
func TestIntegrationAuthInfoStore(t *testing.T) {
|
||||
@ -21,12 +23,79 @@ func TestIntegrationAuthInfoStore(t *testing.T) {
|
||||
sql := db.InitTestDB(t)
|
||||
store := ProvideAuthInfoStore(sql, secretstest.NewFakeSecretsService(), nil)
|
||||
|
||||
t.Run("should be able to auth lables for users", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
|
||||
AuthModule: login.LDAPAuthModule,
|
||||
AuthId: "1",
|
||||
UserId: 1,
|
||||
}))
|
||||
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
|
||||
AuthModule: login.AzureADAuthModule,
|
||||
AuthId: "1",
|
||||
UserId: 1,
|
||||
}))
|
||||
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
|
||||
AuthModule: login.GoogleAuthModule,
|
||||
AuthId: "10",
|
||||
UserId: 2,
|
||||
}))
|
||||
|
||||
labels, err := store.GetUserLabels(ctx, login.GetUserLabelsQuery{UserIDs: []int64{1, 2}})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, labels, 2)
|
||||
|
||||
require.Equal(t, login.AzureADAuthModule, labels[1])
|
||||
require.Equal(t, login.GoogleAuthModule, labels[2])
|
||||
})
|
||||
|
||||
t.Run("should always get the latest used", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
|
||||
AuthModule: login.LDAPAuthModule,
|
||||
AuthId: "1",
|
||||
UserId: 1,
|
||||
}))
|
||||
|
||||
defer func() {
|
||||
GetTime = time.Now
|
||||
}()
|
||||
|
||||
GetTime = func() time.Time {
|
||||
return time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
|
||||
AuthModule: login.AzureADAuthModule,
|
||||
AuthId: "2",
|
||||
UserId: 1,
|
||||
}))
|
||||
|
||||
info, err := store.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
|
||||
UserId: 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, login.AzureADAuthModule, info.AuthModule)
|
||||
assert.Equal(t, "2", info.AuthId)
|
||||
})
|
||||
|
||||
t.Run("should return error when userID and authID is zero value", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
info, err := store.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
|
||||
AuthModule: login.GoogleAuthModule,
|
||||
})
|
||||
|
||||
require.ErrorIs(t, err, user.ErrUserNotFound)
|
||||
require.Nil(t, info)
|
||||
})
|
||||
|
||||
t.Run("should remove duplicates on update", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
setCmd := &login.SetAuthInfoCommand{
|
||||
AuthModule: login.GenericOAuthModule,
|
||||
AuthId: "1",
|
||||
UserId: 1,
|
||||
UserId: 10,
|
||||
}
|
||||
|
||||
require.NoError(t, store.SetAuthInfo(ctx, setCmd))
|
||||
|
@ -42,7 +42,7 @@ func InitDuplicateUserMetrics() {
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||
// if _, err := s.GetLoginStats(ctx); err != nil {
|
||||
// if _, err := s.getLoginStats(ctx); err != nil {
|
||||
// s.logger.Warn("Failed to get authinfo metrics", "error", err.Error())
|
||||
// }
|
||||
updateStatsTicker := time.NewTicker(login.MetricsCollectionInterval)
|
||||
@ -51,7 +51,7 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-updateStatsTicker.C:
|
||||
// if _, err := s.GetLoginStats(ctx); err != nil {
|
||||
// if _, err := s.getLoginStats(ctx); err != nil {
|
||||
// s.logger.Warn("Failed to get authinfo metrics", "error", nil)
|
||||
// }
|
||||
case <-ctx.Done():
|
||||
@ -60,7 +60,26 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
|
||||
func (s *AuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
|
||||
m := map[string]any{}
|
||||
|
||||
loginStats, err := s.getLoginStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get login stats", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries
|
||||
if loginStats.DuplicateUserEntries > 0 {
|
||||
m["stats.users.has_duplicate_user_entries"] = 1
|
||||
} else {
|
||||
m["stats.users.has_duplicate_user_entries"] = 0
|
||||
}
|
||||
m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) getLoginStats(ctx context.Context) (login.LoginStats, error) {
|
||||
var stats login.LoginStats
|
||||
outerErr := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
|
||||
rawSQL := `SELECT
|
||||
@ -86,25 +105,6 @@ func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, er
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
|
||||
m := map[string]any{}
|
||||
|
||||
loginStats, err := s.GetLoginStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get login stats", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries
|
||||
if loginStats.DuplicateUserEntries > 0 {
|
||||
m["stats.users.has_duplicate_user_entries"] = 1
|
||||
} else {
|
||||
m["stats.users.has_duplicate_user_entries"] = 0
|
||||
}
|
||||
m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *AuthInfoStore) duplicateUserEntriesSQL(ctx context.Context) string {
|
||||
userDialect := s.sqlStore.GetDialect().Quote("user")
|
||||
// this query counts how many users have the same login or email.
|
||||
|
76
pkg/services/login/authinfoservice/database/stats_test.go
Normal file
76
pkg/services/login/authinfoservice/database/stats_test.go
Normal file
@ -0,0 +1,76 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
secretstest "github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/userimpl"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
func TestIntegrationAuthInfoStoreStats(t *testing.T) {
|
||||
sql := db.InitTestDB(t)
|
||||
cfg := setting.NewCfg()
|
||||
InitDuplicateUserMetrics()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
usrStore := userimpl.ProvideStore(sql, cfg)
|
||||
for i := 0; i < 5; i++ {
|
||||
usr := &user.User{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Login: fmt.Sprint("user", i),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Created: now,
|
||||
Updated: now,
|
||||
LastSeenAt: now,
|
||||
}
|
||||
_, err := usrStore.Insert(context.Background(), usr)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
var (
|
||||
duplicatedUsers int
|
||||
mixedCasedUsers int
|
||||
hasDuplicatedUsers int
|
||||
)
|
||||
|
||||
if sql.GetDialect().DriverName() != "mysql" {
|
||||
duplicatedUsers, mixedCasedUsers, hasDuplicatedUsers = 2, 1, 1
|
||||
_, err := usrStore.Insert(context.Background(), &user.User{
|
||||
Email: "USERDUPLICATETEST1@TEST.COM",
|
||||
Name: "user name 1",
|
||||
Login: "USER_DUPLICATE_TEST_1_LOGIN",
|
||||
Created: now,
|
||||
Updated: now,
|
||||
LastSeenAt: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// add additional user with duplicate login where DOMAIN is upper case
|
||||
_, err = usrStore.Insert(context.Background(), &user.User{
|
||||
Email: "userduplicatetest1@test.com",
|
||||
Name: "user name 1",
|
||||
Login: "user_duplicate_test_1_login",
|
||||
Created: now,
|
||||
Updated: now,
|
||||
LastSeenAt: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
store := ProvideAuthInfoStore(sql, secretstest.NewFakeSecretsService(), nil)
|
||||
stats, err := store.CollectLoginStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, duplicatedUsers, stats["stats.users.duplicate_user_entries"])
|
||||
require.Equal(t, mixedCasedUsers, stats["stats.users.mixed_cased_users"])
|
||||
require.Equal(t, hasDuplicatedUsers, stats["stats.users.has_duplicate_user_entries"])
|
||||
}
|
@ -2,27 +2,20 @@ package authinfoservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
const genericOAuthModule = "oauth_generic_oauth"
|
||||
|
||||
type Implementation struct {
|
||||
UserProtectionService login.UserProtectionService
|
||||
authInfoStore login.Store
|
||||
logger log.Logger
|
||||
authInfoStore login.Store
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func ProvideAuthInfoService(userProtectionService login.UserProtectionService, authInfoStore login.Store, usageStats usagestats.Service) *Implementation {
|
||||
func ProvideAuthInfoService(authInfoStore login.Store) *Implementation {
|
||||
s := &Implementation{
|
||||
UserProtectionService: userProtectionService,
|
||||
authInfoStore: authInfoStore,
|
||||
logger: log.New("login.authinfo"),
|
||||
authInfoStore: authInfoStore,
|
||||
logger: log.New("login.authinfo"),
|
||||
}
|
||||
// FIXME: disabled metrics until further notice
|
||||
// query performance is slow for more than 20000 users
|
||||
@ -30,158 +23,6 @@ func ProvideAuthInfoService(userProtectionService login.UserProtectionService, a
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserByAuthInfoQuery) (bool, *user.User, *login.UserAuth, error) {
|
||||
authQuery := &login.GetAuthInfoQuery{}
|
||||
|
||||
// Try to find the user by auth module and id first
|
||||
if query.AuthModule != "" && query.AuthId != "" {
|
||||
authQuery.AuthModule = query.AuthModule
|
||||
authQuery.AuthId = query.AuthId
|
||||
|
||||
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
if !errors.Is(err, user.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
// if user id was specified and doesn't match the user_auth entry, remove it
|
||||
if query.UserLookupParams.UserID != nil &&
|
||||
*query.UserLookupParams.UserID != 0 &&
|
||||
*query.UserLookupParams.UserID != userAuth.UserId {
|
||||
if err := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
|
||||
UserAuth: userAuth,
|
||||
}); err != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", err)
|
||||
}
|
||||
|
||||
return false, nil, nil, user.ErrUserNotFound
|
||||
} else {
|
||||
usr, err := s.authInfoStore.GetUserById(ctx, userAuth.UserId)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// if the user has been deleted then remove the entry
|
||||
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
|
||||
UserAuth: userAuth,
|
||||
}); errDel != nil {
|
||||
s.logger.Error("Error removing user_auth entry", "error", errDel)
|
||||
}
|
||||
|
||||
return false, nil, nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return false, nil, nil, err
|
||||
}
|
||||
|
||||
return true, usr, userAuth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil, nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupByOneOf(ctx context.Context, params *login.UserLookupParams) (*user.User, error) {
|
||||
var usr *user.User
|
||||
var err error
|
||||
|
||||
// If not found, try to find the user by id
|
||||
if params.UserID != nil && *params.UserID != 0 {
|
||||
usr, err = s.authInfoStore.GetUserById(ctx, *params.UserID)
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by email address
|
||||
if usr == nil && params.Email != nil && *params.Email != "" {
|
||||
usr, err = s.authInfoStore.GetUserByEmail(ctx, *params.Email)
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by login
|
||||
if usr == nil && params.Login != nil && *params.Login != "" {
|
||||
usr, err = s.authInfoStore.GetUserByLogin(ctx, *params.Login)
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if usr == nil {
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*login.UserAuth, error) {
|
||||
if authModule == genericOAuthModule && userID != 0 {
|
||||
authQuery := &login.GetAuthInfoQuery{}
|
||||
authQuery.AuthModule = authModule
|
||||
authQuery.AuthId = authId
|
||||
authQuery.UserId = userID
|
||||
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userAuth, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) LookupAndUpdate(ctx context.Context, query *login.GetUserByAuthInfoQuery) (*user.User, error) {
|
||||
// 1. LookupAndFix = auth info, user, error
|
||||
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
|
||||
foundUser, usr, authInfo, err := s.LookupAndFix(ctx, query)
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. FindByUserDetails
|
||||
if !foundUser {
|
||||
usr, err = s.LookupByOneOf(ctx, &query.UserLookupParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.UserProtectionService.AllowUserMapping(usr, query.AuthModule); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Special case for generic oauth duplicates
|
||||
ai, err := s.GenericOAuthLookup(ctx, query.AuthModule, query.AuthId, usr.ID)
|
||||
if !errors.Is(err, user.ErrUserNotFound) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ai != nil {
|
||||
authInfo = ai
|
||||
}
|
||||
|
||||
if query.AuthModule != "" {
|
||||
if authInfo == nil {
|
||||
cmd := &login.SetAuthInfoCommand{
|
||||
UserId: usr.ID,
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
if err := s.authInfoStore.SetAuthInfo(ctx, cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := s.authInfoStore.UpdateAuthInfoDate(ctx, authInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
return s.authInfoStore.GetAuthInfo(ctx, query)
|
||||
}
|
||||
|
@ -1,568 +0,0 @@
|
||||
package authinfoservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
|
||||
"github.com/grafana/grafana/pkg/services/org/orgimpl"
|
||||
"github.com/grafana/grafana/pkg/services/quota/quotaimpl"
|
||||
"github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/userimpl"
|
||||
)
|
||||
|
||||
//nolint:goconst
|
||||
func TestUserAuth(t *testing.T) {
|
||||
sqlStore := db.InitTestDB(t)
|
||||
authInfoStore := newFakeAuthInfoStore()
|
||||
srv := ProvideAuthInfoService(
|
||||
&OSSUserProtectionImpl{},
|
||||
authInfoStore,
|
||||
&usagestats.UsageStatsMock{},
|
||||
)
|
||||
|
||||
t.Run("Given 5 users", func(t *testing.T) {
|
||||
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
|
||||
require.NoError(t, err)
|
||||
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
_, err := usrSvc.Create(context.Background(), &cmd)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
t.Run("Can find existing user", func(t *testing.T) {
|
||||
// By Login
|
||||
userlogin := "loginuser0"
|
||||
authInfoStore.ExpectedUser = &user.User{
|
||||
Login: "loginuser0",
|
||||
ID: 1,
|
||||
Email: "user1@test.com",
|
||||
}
|
||||
query := &login.GetUserByAuthInfoQuery{UserLookupParams: login.UserLookupParams{Login: &userlogin}}
|
||||
usr, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Login, userlogin)
|
||||
|
||||
// By ID
|
||||
id := usr.ID
|
||||
|
||||
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
|
||||
UserID: &id,
|
||||
})
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.ID, id)
|
||||
|
||||
// By Email
|
||||
email := "user1@test.com"
|
||||
|
||||
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
|
||||
Email: &email,
|
||||
})
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Email, email)
|
||||
|
||||
authInfoStore.ExpectedUser = nil
|
||||
// Don't find nonexistent user
|
||||
email = "nonexistent@test.com"
|
||||
|
||||
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
|
||||
Email: &email,
|
||||
})
|
||||
|
||||
require.Equal(t, user.ErrUserNotFound, err)
|
||||
require.Nil(t, usr)
|
||||
})
|
||||
|
||||
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
|
||||
// get nonexistent user_auth entry
|
||||
authInfoStore.ExpectedUser = &user.User{}
|
||||
authInfoStore.ExpectedError = user.ErrUserNotFound
|
||||
query := &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
usr, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Equal(t, user.ErrUserNotFound, err)
|
||||
require.Nil(t, usr)
|
||||
|
||||
// create user_auth entry
|
||||
userlogin := "loginuser0"
|
||||
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
|
||||
authInfoStore.ExpectedError = nil
|
||||
authInfoStore.ExpectedOAuth = &login.UserAuth{Id: 1}
|
||||
query.UserLookupParams.Login = &userlogin
|
||||
usr, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Login, userlogin)
|
||||
|
||||
// get via user_auth
|
||||
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
usr, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Login, userlogin)
|
||||
|
||||
// get with non-matching id
|
||||
idPlusOne := usr.ID + 1
|
||||
|
||||
authInfoStore.ExpectedUser.Login = "loginuser1"
|
||||
query.UserLookupParams.UserID = &idPlusOne
|
||||
usr, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Login, "loginuser1")
|
||||
|
||||
// get via user_auth
|
||||
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
usr, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, usr.Login, "loginuser1")
|
||||
|
||||
// remove user
|
||||
err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
|
||||
_, err := sess.Exec("DELETE FROM "+sqlStore.Dialect.Quote("user")+" WHERE id=?", usr.ID)
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
authInfoStore.ExpectedUser = nil
|
||||
authInfoStore.ExpectedError = user.ErrUserNotFound
|
||||
// get via user_auth for deleted user
|
||||
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
|
||||
usr, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Equal(t, err, user.ErrUserNotFound)
|
||||
require.Nil(t, usr)
|
||||
})
|
||||
|
||||
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
idToken := "testidtoken"
|
||||
token = token.WithExtra(map[string]any{"id_token": idToken})
|
||||
|
||||
// Find a user to set tokens on
|
||||
userlogin := "loginuser0"
|
||||
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
|
||||
authInfoStore.ExpectedError = nil
|
||||
authInfoStore.ExpectedOAuth = &login.UserAuth{
|
||||
Id: 1,
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthTokenType: token.TokenType,
|
||||
OAuthIdToken: idToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
}
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
query := &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
|
||||
cmd := &login.UpdateAuthInfoCommand{
|
||||
UserId: user.ID,
|
||||
AuthId: query.AuthId,
|
||||
AuthModule: query.AuthModule,
|
||||
OAuthToken: token,
|
||||
}
|
||||
err = srv.authInfoStore.UpdateAuthInfo(context.Background(), cmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
|
||||
getAuthQuery := &login.GetAuthInfoQuery{
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
authInfo, err := srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, token.AccessToken, authInfo.OAuthAccessToken)
|
||||
require.Equal(t, token.RefreshToken, authInfo.OAuthRefreshToken)
|
||||
require.Equal(t, token.TokenType, authInfo.OAuthTokenType)
|
||||
require.Equal(t, idToken, authInfo.OAuthIdToken)
|
||||
})
|
||||
|
||||
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
|
||||
// Restore after destructive operation
|
||||
sqlStore = db.InitTestDB(t)
|
||||
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
|
||||
require.NoError(t, err)
|
||||
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
_, err = usrSvc.Create(context.Background(), &cmd)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Find a user to set tokens on
|
||||
userlogin := "loginuser0"
|
||||
|
||||
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &login.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
database.GetTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
|
||||
// Add a second auth module for this user
|
||||
// Have this module's last log-in be more recent
|
||||
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
|
||||
query = &login.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
database.GetTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test2"
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery := &login.GetAuthInfoQuery{
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, authInfo.AuthModule, "test2")
|
||||
|
||||
// "log in" again with the first auth module
|
||||
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1"}
|
||||
err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd)
|
||||
|
||||
require.Nil(t, err)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test1"
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery = &login.GetAuthInfoQuery{
|
||||
UserId: user.ID,
|
||||
}
|
||||
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, authInfo.AuthModule, "test1")
|
||||
})
|
||||
|
||||
t.Run("Keeps track of last used auth_module when not using oauth", func(t *testing.T) {
|
||||
// Restore after destructive operation
|
||||
sqlStore = db.InitTestDB(t)
|
||||
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
|
||||
require.NoError(t, err)
|
||||
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
}
|
||||
_, err := usrSvc.Create(context.Background(), &cmd)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// Find a user to set tokens on
|
||||
userlogin := "loginuser0"
|
||||
|
||||
fixedTime := time.Now()
|
||||
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
|
||||
// Make the first log-in during the past
|
||||
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -2) }
|
||||
queryOne := &login.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err := srv.LookupAndUpdate(context.Background(), queryOne)
|
||||
database.GetTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
|
||||
// Add a second auth module for this user
|
||||
// Have this module's last log-in be more recent
|
||||
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -1) }
|
||||
queryTwo := &login.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err = srv.LookupAndUpdate(context.Background(), queryTwo)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
|
||||
// Get the latest entry by not supply an authmodule or authid
|
||||
getAuthQuery := &login.GetAuthInfoQuery{
|
||||
UserId: user.ID,
|
||||
}
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test2"
|
||||
|
||||
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test2", authInfo.AuthModule)
|
||||
|
||||
// Now reuse first auth module and make sure it's updated to the most recent
|
||||
database.GetTime = func() time.Time { return fixedTime }
|
||||
|
||||
// add oauth info to auth_info to make sure update date does not overwrite it
|
||||
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1", OAuthToken: &oauth2.Token{
|
||||
AccessToken: "access_token",
|
||||
TokenType: "token_type",
|
||||
RefreshToken: "refresh_token",
|
||||
Expiry: fixedTime,
|
||||
}}
|
||||
err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd)
|
||||
require.Nil(t, err)
|
||||
user, err = srv.LookupAndUpdate(context.Background(), queryOne)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test1"
|
||||
authInfoStore.ExpectedOAuth.OAuthAccessToken = "access_token"
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test1", authInfo.AuthModule)
|
||||
// make sure oauth info is not overwritten by update date
|
||||
require.Equal(t, "access_token", authInfo.OAuthAccessToken)
|
||||
|
||||
// Now reuse second auth module and make sure it's updated to the most recent
|
||||
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, 1) }
|
||||
user, err = srv.LookupAndUpdate(context.Background(), queryTwo)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test2"
|
||||
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test2", authInfo.AuthModule)
|
||||
|
||||
// Ensure test 1 did not have its entry modified
|
||||
getAuthQueryUnchanged := &login.GetAuthInfoQuery{
|
||||
UserId: user.ID,
|
||||
AuthModule: "test1",
|
||||
}
|
||||
authInfoStore.ExpectedOAuth.AuthModule = "test1"
|
||||
|
||||
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "test1", authInfo.AuthModule)
|
||||
})
|
||||
|
||||
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
|
||||
// Find a user to set tokens on
|
||||
userlogin := "loginuser0"
|
||||
|
||||
// Expect to pass since there's a matching login user
|
||||
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query := &login.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: login.UserLookupParams{
|
||||
Login: &userlogin,
|
||||
}}
|
||||
user, err := srv.LookupAndUpdate(context.Background(), query)
|
||||
database.GetTime = time.Now
|
||||
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, user.Login, userlogin)
|
||||
|
||||
otherLoginUser := "aloginuser"
|
||||
// Should throw a "user not found" error since there's no matching login user
|
||||
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
|
||||
query = &login.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: login.UserLookupParams{
|
||||
Login: &otherLoginUser,
|
||||
}}
|
||||
authInfoStore.ExpectedError = errors.New("some error")
|
||||
|
||||
user, err = srv.LookupAndUpdate(context.Background(), query)
|
||||
database.GetTime = time.Now
|
||||
|
||||
require.NotNil(t, err)
|
||||
require.Nil(t, user)
|
||||
authInfoStore.ExpectedError = nil
|
||||
})
|
||||
|
||||
t.Run("should be able to run loginstats query in all dbs", func(t *testing.T) {
|
||||
// we need to see that we can run queries for all db
|
||||
// as it is only a concern for postgres/sqllite3
|
||||
// where we have duplicate users
|
||||
|
||||
// Restore after destructive operation
|
||||
sqlStore = db.InitTestDB(t)
|
||||
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
|
||||
require.NoError(t, err)
|
||||
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
OrgID: 1,
|
||||
}
|
||||
_, err := usrSvc.Create(context.Background(), &cmd)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
_, err = srv.authInfoStore.GetLoginStats(context.Background())
|
||||
require.Nil(t, err)
|
||||
})
|
||||
|
||||
t.Run("calculate metrics on duplicate userstats", func(t *testing.T) {
|
||||
// Restore after destructive operation
|
||||
sqlStore = db.InitTestDB(t)
|
||||
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
|
||||
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
|
||||
require.NoError(t, err)
|
||||
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
cmd := user.CreateUserCommand{
|
||||
Email: fmt.Sprint("user", i, "@test.com"),
|
||||
Name: fmt.Sprint("user", i),
|
||||
Login: fmt.Sprint("loginuser", i),
|
||||
OrgID: 1,
|
||||
}
|
||||
_, err := usrSvc.Create(context.Background(), &cmd)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
// "Skipping duplicate users test for mysql as it does make unique constraint case insensitive by default
|
||||
if sqlStore.GetDialect().DriverName() != "mysql" {
|
||||
dupUserEmailcmd := user.CreateUserCommand{
|
||||
Email: "USERDUPLICATETEST1@TEST.COM",
|
||||
Name: "user name 1",
|
||||
Login: "USER_DUPLICATE_TEST_1_LOGIN",
|
||||
}
|
||||
_, err := usrSvc.Create(context.Background(), &dupUserEmailcmd)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add additional user with duplicate login where DOMAIN is upper case
|
||||
dupUserLogincmd := user.CreateUserCommand{
|
||||
Email: "userduplicatetest1@test.com",
|
||||
Name: "user name 1",
|
||||
Login: "user_duplicate_test_1_login",
|
||||
}
|
||||
_, err = usrSvc.Create(context.Background(), &dupUserLogincmd)
|
||||
require.NoError(t, err)
|
||||
authInfoStore.ExpectedUser = &user.User{
|
||||
Email: "userduplicatetest1@test.com",
|
||||
Name: "user name 1",
|
||||
Login: "user_duplicate_test_1_login",
|
||||
}
|
||||
authInfoStore.ExpectedDuplicateUserEntries = 2
|
||||
authInfoStore.ExpectedHasDuplicateUserEntries = 1
|
||||
authInfoStore.ExpectedLoginStats = login.LoginStats{
|
||||
DuplicateUserEntries: 2,
|
||||
MixedCasedUsers: 1,
|
||||
}
|
||||
// require metrics and statistics to be 2
|
||||
m, err := srv.authInfoStore.CollectLoginStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, m["stats.users.duplicate_user_entries"])
|
||||
require.Equal(t, 1, m["stats.users.has_duplicate_user_entries"])
|
||||
|
||||
require.Equal(t, 1, m["stats.users.mixed_cased_users"])
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type FakeAuthInfoStore struct {
|
||||
login.AuthInfoService
|
||||
ExpectedError error
|
||||
ExpectedUser *user.User
|
||||
ExpectedOAuth *login.UserAuth
|
||||
ExpectedDuplicateUserEntries int
|
||||
ExpectedHasDuplicateUserEntries int
|
||||
ExpectedLoginStats login.LoginStats
|
||||
}
|
||||
|
||||
func newFakeAuthInfoStore() *FakeAuthInfoStore {
|
||||
return &FakeAuthInfoStore{}
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
return f.ExpectedOAuth, f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
|
||||
var res = make(map[string]any)
|
||||
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
|
||||
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
|
||||
res["stats.users.duplicate_user_entries_by_login"] = 0
|
||||
res["stats.users.has_duplicate_user_entries_by_login"] = 0
|
||||
res["stats.users.duplicate_user_entries_by_email"] = 0
|
||||
res["stats.users.has_duplicate_user_entries_by_email"] = 0
|
||||
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
|
||||
return res, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
|
||||
return f.ExpectedLoginStats, f.ExpectedError
|
||||
}
|
@ -4,14 +4,12 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
type AuthInfoServiceFake struct {
|
||||
login.AuthInfoService
|
||||
LatestUserID int64
|
||||
ExpectedUserAuth *login.UserAuth
|
||||
ExpectedUser *user.User
|
||||
ExpectedExternalUser *login.ExternalUserInfo
|
||||
ExpectedError error
|
||||
ExpectedLabels map[int64]string
|
||||
@ -20,15 +18,6 @@ type AuthInfoServiceFake struct {
|
||||
UpdateAuthInfoFn func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *login.GetUserByAuthInfoQuery) (*user.User, error) {
|
||||
if query.UserLookupParams.UserID != nil {
|
||||
a.LatestUserID = *query.UserLookupParams.UserID
|
||||
} else {
|
||||
a.LatestUserID = 0
|
||||
}
|
||||
return a.ExpectedUser, a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
||||
a.LatestUserID = query.UserId
|
||||
return a.ExpectedUserAuth, a.ExpectedError
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/socialtest"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||
@ -230,7 +229,7 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social
|
||||
}
|
||||
|
||||
authInfoStore := &FakeAuthInfoStore{}
|
||||
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
|
||||
authInfoService := authinfoservice.ProvideAuthInfoService(authInfoStore)
|
||||
return &Service{
|
||||
Cfg: setting.NewCfg(),
|
||||
SocialService: socialService,
|
||||
|
Loading…
Reference in New Issue
Block a user