Login: remove unused function (#78442)

* Move test to the db so we test the queries and not just testing the mock

* Remove unused function and dependencies

* Remove unused functions from the database

* Add some integration tests
This commit is contained in:
Karl Persson 2023-11-21 11:44:13 +01:00 committed by GitHub
parent e2f2d8b3d6
commit d42201dbf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 174 additions and 836 deletions

View File

@ -18,7 +18,6 @@ import (
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/db/dbtest"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/login/socialtest"
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
@ -66,9 +65,7 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) {
secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore))
authInfoStore := authinfostore.ProvideAuthInfoStore(sqlStore, secretsService, userMock)
srv := authinfoservice.ProvideAuthInfoService(
&authinfoservice.OSSUserProtectionImpl{},
authInfoStore,
&usagestats.UsageStatsMock{},
)
hs.authInfoService = srv
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, quotatest.New(false, nil))

View File

@ -33,7 +33,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
userProtection := &authinfoservice.OSSUserProtectionImpl{}
authFakeNil := &logintest.AuthInfoServiceFake{
ExpectedUser: nil,
ExpectedError: user.ErrUserNotFound,
SetAuthInfoFn: func(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
return nil
@ -43,7 +42,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
},
}
authFakeUserID := &logintest.AuthInfoServiceFake{
ExpectedUser: nil,
ExpectedError: nil,
ExpectedUserAuth: &login.UserAuth{
AuthModule: "oauth",
@ -68,7 +66,6 @@ func TestUserSync_SyncUserHook(t *testing.T) {
}}
userServiceNil := &usertest.FakeUserService{
ExpectedUser: nil,
ExpectedError: user.ErrUserNotFound,
CreateFn: func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error) {
return &user.User{

View File

@ -3,12 +3,10 @@ package login
import (
"context"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
type AuthInfoService interface {
LookupAndUpdate(ctx context.Context, query *GetUserByAuthInfoQuery) (*user.User, error)
GetAuthInfo(ctx context.Context, query *GetAuthInfoQuery) (*UserAuth, error)
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
@ -21,15 +19,9 @@ type Store interface {
GetUserLabels(ctx context.Context, query GetUserLabelsQuery) (map[int64]string, error)
SetAuthInfo(ctx context.Context, cmd *SetAuthInfoCommand) error
UpdateAuthInfo(ctx context.Context, cmd *UpdateAuthInfoCommand) error
UpdateAuthInfoDate(ctx context.Context, authInfo *UserAuth) error
DeleteAuthInfo(ctx context.Context, cmd *DeleteAuthInfoCommand) error
DeleteUserAuthInfo(ctx context.Context, userID int64) error
GetUserById(ctx context.Context, id int64) (*user.User, error)
GetUserByLogin(ctx context.Context, login string) (*user.User, error)
GetUserByEmail(ctx context.Context, email string) (*user.User, error)
CollectLoginStats(ctx context.Context) (map[string]any, error)
RunMetricsCollection(ctx context.Context) error
GetLoginStats(ctx context.Context) (LoginStats, error)
}
const (

View File

@ -153,23 +153,6 @@ func (s *AuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoC
})
}
// UpdateAuthInfoDate updates the auth info for the user with the latest date.
// Avoids overlapping entries hiding the last used one (ex: LDAP->SAML->LDAP).
func (s *AuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
authInfo.Created = GetTime()
cond := &login.UserAuth{
Id: authInfo.Id,
UserId: authInfo.UserId,
AuthModule: authInfo.AuthModule,
}
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
_, err := sess.Cols("created").Update(authInfo, cond)
return err
})
}
func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
authUser := &login.UserAuth{
UserId: cmd.UserId,
@ -239,13 +222,6 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAut
})
}
func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
_, err := sess.Delete(cmd.UserAuth)
return err
})
}
func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
var rawSQL = "DELETE FROM user_auth WHERE user_id = ?"
@ -254,36 +230,6 @@ func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) er
})
}
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
query := user.GetUserByIDQuery{ID: id}
user, err := s.userService.GetByID(ctx, &query)
if err != nil {
return nil, err
}
return user, nil
}
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
query := user.GetUserByLoginQuery{LoginOrEmail: login}
usr, err := s.userService.GetByLogin(ctx, &query)
if err != nil {
return nil, err
}
return usr, nil
}
func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
query := user.GetUserByEmailQuery{Email: email}
usr, err := s.userService.GetByEmail(ctx, &query)
if err != nil {
return nil, err
}
return usr, nil
}
// decodeAndDecrypt will decode the string with the standard base64 decoder and then decrypt it
func (s *AuthInfoStore) decodeAndDecrypt(str string) (string, error) {
// Bail out if empty string since it'll cause a segfault in Decrypt

View File

@ -5,12 +5,14 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/services/login"
secretstest "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/user"
)
func TestIntegrationAuthInfoStore(t *testing.T) {
@ -21,12 +23,79 @@ func TestIntegrationAuthInfoStore(t *testing.T) {
sql := db.InitTestDB(t)
store := ProvideAuthInfoStore(sql, secretstest.NewFakeSecretsService(), nil)
t.Run("should be able to auth lables for users", func(t *testing.T) {
ctx := context.Background()
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
AuthModule: login.LDAPAuthModule,
AuthId: "1",
UserId: 1,
}))
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
AuthModule: login.AzureADAuthModule,
AuthId: "1",
UserId: 1,
}))
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
AuthModule: login.GoogleAuthModule,
AuthId: "10",
UserId: 2,
}))
labels, err := store.GetUserLabels(ctx, login.GetUserLabelsQuery{UserIDs: []int64{1, 2}})
require.NoError(t, err)
require.Len(t, labels, 2)
require.Equal(t, login.AzureADAuthModule, labels[1])
require.Equal(t, login.GoogleAuthModule, labels[2])
})
t.Run("should always get the latest used", func(t *testing.T) {
ctx := context.Background()
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
AuthModule: login.LDAPAuthModule,
AuthId: "1",
UserId: 1,
}))
defer func() {
GetTime = time.Now
}()
GetTime = func() time.Time {
return time.Now().Add(1 * time.Hour)
}
require.NoError(t, store.SetAuthInfo(ctx, &login.SetAuthInfoCommand{
AuthModule: login.AzureADAuthModule,
AuthId: "2",
UserId: 1,
}))
info, err := store.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
UserId: 1,
})
require.NoError(t, err)
assert.Equal(t, login.AzureADAuthModule, info.AuthModule)
assert.Equal(t, "2", info.AuthId)
})
t.Run("should return error when userID and authID is zero value", func(t *testing.T) {
ctx := context.Background()
info, err := store.GetAuthInfo(ctx, &login.GetAuthInfoQuery{
AuthModule: login.GoogleAuthModule,
})
require.ErrorIs(t, err, user.ErrUserNotFound)
require.Nil(t, info)
})
t.Run("should remove duplicates on update", func(t *testing.T) {
ctx := context.Background()
setCmd := &login.SetAuthInfoCommand{
AuthModule: login.GenericOAuthModule,
AuthId: "1",
UserId: 1,
UserId: 10,
}
require.NoError(t, store.SetAuthInfo(ctx, setCmd))

View File

@ -42,7 +42,7 @@ func InitDuplicateUserMetrics() {
}
func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
// if _, err := s.GetLoginStats(ctx); err != nil {
// if _, err := s.getLoginStats(ctx); err != nil {
// s.logger.Warn("Failed to get authinfo metrics", "error", err.Error())
// }
updateStatsTicker := time.NewTicker(login.MetricsCollectionInterval)
@ -51,7 +51,7 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
for {
select {
case <-updateStatsTicker.C:
// if _, err := s.GetLoginStats(ctx); err != nil {
// if _, err := s.getLoginStats(ctx); err != nil {
// s.logger.Warn("Failed to get authinfo metrics", "error", nil)
// }
case <-ctx.Done():
@ -60,7 +60,26 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
}
}
func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
func (s *AuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
m := map[string]any{}
loginStats, err := s.getLoginStats(ctx)
if err != nil {
s.logger.Error("Failed to get login stats", "error", err)
return nil, err
}
m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries
if loginStats.DuplicateUserEntries > 0 {
m["stats.users.has_duplicate_user_entries"] = 1
} else {
m["stats.users.has_duplicate_user_entries"] = 0
}
m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers
return m, nil
}
func (s *AuthInfoStore) getLoginStats(ctx context.Context) (login.LoginStats, error) {
var stats login.LoginStats
outerErr := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
rawSQL := `SELECT
@ -86,25 +105,6 @@ func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, er
return stats, nil
}
func (s *AuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
m := map[string]any{}
loginStats, err := s.GetLoginStats(ctx)
if err != nil {
s.logger.Error("Failed to get login stats", "error", err)
return nil, err
}
m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries
if loginStats.DuplicateUserEntries > 0 {
m["stats.users.has_duplicate_user_entries"] = 1
} else {
m["stats.users.has_duplicate_user_entries"] = 0
}
m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers
return m, nil
}
func (s *AuthInfoStore) duplicateUserEntriesSQL(ctx context.Context) string {
userDialect := s.sqlStore.GetDialect().Quote("user")
// this query counts how many users have the same login or email.

View File

@ -0,0 +1,76 @@
package database
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/db"
secretstest "github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/userimpl"
"github.com/grafana/grafana/pkg/setting"
)
func TestIntegrationAuthInfoStoreStats(t *testing.T) {
sql := db.InitTestDB(t)
cfg := setting.NewCfg()
InitDuplicateUserMetrics()
now := time.Now()
usrStore := userimpl.ProvideStore(sql, cfg)
for i := 0; i < 5; i++ {
usr := &user.User{
Email: fmt.Sprint("user", i, "@test.com"),
Login: fmt.Sprint("user", i),
Name: fmt.Sprint("user", i),
Created: now,
Updated: now,
LastSeenAt: now,
}
_, err := usrStore.Insert(context.Background(), usr)
require.Nil(t, err)
}
var (
duplicatedUsers int
mixedCasedUsers int
hasDuplicatedUsers int
)
if sql.GetDialect().DriverName() != "mysql" {
duplicatedUsers, mixedCasedUsers, hasDuplicatedUsers = 2, 1, 1
_, err := usrStore.Insert(context.Background(), &user.User{
Email: "USERDUPLICATETEST1@TEST.COM",
Name: "user name 1",
Login: "USER_DUPLICATE_TEST_1_LOGIN",
Created: now,
Updated: now,
LastSeenAt: now,
})
require.NoError(t, err)
// add additional user with duplicate login where DOMAIN is upper case
_, err = usrStore.Insert(context.Background(), &user.User{
Email: "userduplicatetest1@test.com",
Name: "user name 1",
Login: "user_duplicate_test_1_login",
Created: now,
Updated: now,
LastSeenAt: now,
})
require.NoError(t, err)
}
store := ProvideAuthInfoStore(sql, secretstest.NewFakeSecretsService(), nil)
stats, err := store.CollectLoginStats(context.Background())
require.NoError(t, err)
require.Equal(t, duplicatedUsers, stats["stats.users.duplicate_user_entries"])
require.Equal(t, mixedCasedUsers, stats["stats.users.mixed_cased_users"])
require.Equal(t, hasDuplicatedUsers, stats["stats.users.has_duplicate_user_entries"])
}

View File

@ -2,27 +2,20 @@ package authinfoservice
import (
"context"
"errors"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
)
const genericOAuthModule = "oauth_generic_oauth"
type Implementation struct {
UserProtectionService login.UserProtectionService
authInfoStore login.Store
logger log.Logger
authInfoStore login.Store
logger log.Logger
}
func ProvideAuthInfoService(userProtectionService login.UserProtectionService, authInfoStore login.Store, usageStats usagestats.Service) *Implementation {
func ProvideAuthInfoService(authInfoStore login.Store) *Implementation {
s := &Implementation{
UserProtectionService: userProtectionService,
authInfoStore: authInfoStore,
logger: log.New("login.authinfo"),
authInfoStore: authInfoStore,
logger: log.New("login.authinfo"),
}
// FIXME: disabled metrics until further notice
// query performance is slow for more than 20000 users
@ -30,158 +23,6 @@ func ProvideAuthInfoService(userProtectionService login.UserProtectionService, a
return s
}
func (s *Implementation) LookupAndFix(ctx context.Context, query *login.GetUserByAuthInfoQuery) (bool, *user.User, *login.UserAuth, error) {
authQuery := &login.GetAuthInfoQuery{}
// Try to find the user by auth module and id first
if query.AuthModule != "" && query.AuthId != "" {
authQuery.AuthModule = query.AuthModule
authQuery.AuthId = query.AuthId
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
if !errors.Is(err, user.ErrUserNotFound) {
if err != nil {
return false, nil, nil, err
}
// if user id was specified and doesn't match the user_auth entry, remove it
if query.UserLookupParams.UserID != nil &&
*query.UserLookupParams.UserID != 0 &&
*query.UserLookupParams.UserID != userAuth.UserId {
if err := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
UserAuth: userAuth,
}); err != nil {
s.logger.Error("Error removing user_auth entry", "error", err)
}
return false, nil, nil, user.ErrUserNotFound
} else {
usr, err := s.authInfoStore.GetUserById(ctx, userAuth.UserId)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// if the user has been deleted then remove the entry
if errDel := s.authInfoStore.DeleteAuthInfo(ctx, &login.DeleteAuthInfoCommand{
UserAuth: userAuth,
}); errDel != nil {
s.logger.Error("Error removing user_auth entry", "error", errDel)
}
return false, nil, nil, user.ErrUserNotFound
}
return false, nil, nil, err
}
return true, usr, userAuth, nil
}
}
}
return false, nil, nil, user.ErrUserNotFound
}
func (s *Implementation) LookupByOneOf(ctx context.Context, params *login.UserLookupParams) (*user.User, error) {
var usr *user.User
var err error
// If not found, try to find the user by id
if params.UserID != nil && *params.UserID != 0 {
usr, err = s.authInfoStore.GetUserById(ctx, *params.UserID)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return nil, err
}
}
// If not found, try to find the user by email address
if usr == nil && params.Email != nil && *params.Email != "" {
usr, err = s.authInfoStore.GetUserByEmail(ctx, *params.Email)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return nil, err
}
}
// If not found, try to find the user by login
if usr == nil && params.Login != nil && *params.Login != "" {
usr, err = s.authInfoStore.GetUserByLogin(ctx, *params.Login)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return nil, err
}
}
if usr == nil {
return nil, user.ErrUserNotFound
}
return usr, nil
}
func (s *Implementation) GenericOAuthLookup(ctx context.Context, authModule string, authId string, userID int64) (*login.UserAuth, error) {
if authModule == genericOAuthModule && userID != 0 {
authQuery := &login.GetAuthInfoQuery{}
authQuery.AuthModule = authModule
authQuery.AuthId = authId
authQuery.UserId = userID
userAuth, err := s.authInfoStore.GetAuthInfo(ctx, authQuery)
if err != nil {
return nil, err
}
return userAuth, nil
}
return nil, nil
}
func (s *Implementation) LookupAndUpdate(ctx context.Context, query *login.GetUserByAuthInfoQuery) (*user.User, error) {
// 1. LookupAndFix = auth info, user, error
// TODO: Not a big fan of the fact that we are deleting auth info here, might want to move that
foundUser, usr, authInfo, err := s.LookupAndFix(ctx, query)
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
return nil, err
}
// 2. FindByUserDetails
if !foundUser {
usr, err = s.LookupByOneOf(ctx, &query.UserLookupParams)
if err != nil {
return nil, err
}
}
if err := s.UserProtectionService.AllowUserMapping(usr, query.AuthModule); err != nil {
return nil, err
}
// Special case for generic oauth duplicates
ai, err := s.GenericOAuthLookup(ctx, query.AuthModule, query.AuthId, usr.ID)
if !errors.Is(err, user.ErrUserNotFound) {
if err != nil {
return nil, err
}
}
if ai != nil {
authInfo = ai
}
if query.AuthModule != "" {
if authInfo == nil {
cmd := &login.SetAuthInfoCommand{
UserId: usr.ID,
AuthModule: query.AuthModule,
AuthId: query.AuthId,
}
if err := s.authInfoStore.SetAuthInfo(ctx, cmd); err != nil {
return nil, err
}
} else {
if err := s.authInfoStore.UpdateAuthInfoDate(ctx, authInfo); err != nil {
return nil, err
}
}
}
return usr, nil
}
func (s *Implementation) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
return s.authInfoStore.GetAuthInfo(ctx, query)
}

View File

@ -1,568 +0,0 @@
package authinfoservice
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/org/orgimpl"
"github.com/grafana/grafana/pkg/services/quota/quotaimpl"
"github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/userimpl"
)
//nolint:goconst
func TestUserAuth(t *testing.T) {
sqlStore := db.InitTestDB(t)
authInfoStore := newFakeAuthInfoStore()
srv := ProvideAuthInfoService(
&OSSUserProtectionImpl{},
authInfoStore,
&usagestats.UsageStatsMock{},
)
t.Run("Given 5 users", func(t *testing.T) {
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
require.NoError(t, err)
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
require.NoError(t, err)
for i := 0; i < 5; i++ {
cmd := user.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
}
_, err := usrSvc.Create(context.Background(), &cmd)
require.Nil(t, err)
}
t.Run("Can find existing user", func(t *testing.T) {
// By Login
userlogin := "loginuser0"
authInfoStore.ExpectedUser = &user.User{
Login: "loginuser0",
ID: 1,
Email: "user1@test.com",
}
query := &login.GetUserByAuthInfoQuery{UserLookupParams: login.UserLookupParams{Login: &userlogin}}
usr, err := srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, usr.Login, userlogin)
// By ID
id := usr.ID
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
UserID: &id,
})
require.Nil(t, err)
require.Equal(t, usr.ID, id)
// By Email
email := "user1@test.com"
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
Email: &email,
})
require.Nil(t, err)
require.Equal(t, usr.Email, email)
authInfoStore.ExpectedUser = nil
// Don't find nonexistent user
email = "nonexistent@test.com"
usr, err = srv.LookupByOneOf(context.Background(), &login.UserLookupParams{
Email: &email,
})
require.Equal(t, user.ErrUserNotFound, err)
require.Nil(t, usr)
})
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
// get nonexistent user_auth entry
authInfoStore.ExpectedUser = &user.User{}
authInfoStore.ExpectedError = user.ErrUserNotFound
query := &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err := srv.LookupAndUpdate(context.Background(), query)
require.Equal(t, user.ErrUserNotFound, err)
require.Nil(t, usr)
// create user_auth entry
userlogin := "loginuser0"
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
authInfoStore.ExpectedError = nil
authInfoStore.ExpectedOAuth = &login.UserAuth{Id: 1}
query.UserLookupParams.Login = &userlogin
usr, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, usr.Login, userlogin)
// get via user_auth
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, usr.Login, userlogin)
// get with non-matching id
idPlusOne := usr.ID + 1
authInfoStore.ExpectedUser.Login = "loginuser1"
query.UserLookupParams.UserID = &idPlusOne
usr, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, usr.Login, "loginuser1")
// get via user_auth
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err = srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, usr.Login, "loginuser1")
// remove user
err = sqlStore.WithDbSession(context.Background(), func(sess *db.Session) error {
_, err := sess.Exec("DELETE FROM "+sqlStore.Dialect.Quote("user")+" WHERE id=?", usr.ID)
return err
})
require.NoError(t, err)
authInfoStore.ExpectedUser = nil
authInfoStore.ExpectedError = user.ErrUserNotFound
// get via user_auth for deleted user
query = &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err = srv.LookupAndUpdate(context.Background(), query)
require.Equal(t, err, user.ErrUserNotFound)
require.Nil(t, usr)
})
t.Run("Can set & retrieve oauth token information", func(t *testing.T) {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now(),
TokenType: "Bearer",
}
idToken := "testidtoken"
token = token.WithExtra(map[string]any{"id_token": idToken})
// Find a user to set tokens on
userlogin := "loginuser0"
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
authInfoStore.ExpectedError = nil
authInfoStore.ExpectedOAuth = &login.UserAuth{
Id: 1,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthTokenType: token.TokenType,
OAuthIdToken: idToken,
OAuthExpiry: token.Expiry,
}
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
query := &login.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err := srv.LookupAndUpdate(context.Background(), query)
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
cmd := &login.UpdateAuthInfoCommand{
UserId: user.ID,
AuthId: query.AuthId,
AuthModule: query.AuthModule,
OAuthToken: token,
}
err = srv.authInfoStore.UpdateAuthInfo(context.Background(), cmd)
require.Nil(t, err)
getAuthQuery := &login.GetAuthInfoQuery{
UserId: user.ID,
}
authInfo, err := srv.authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, token.AccessToken, authInfo.OAuthAccessToken)
require.Equal(t, token.RefreshToken, authInfo.OAuthRefreshToken)
require.Equal(t, token.TokenType, authInfo.OAuthTokenType)
require.Equal(t, idToken, authInfo.OAuthIdToken)
})
t.Run("Always return the most recently used auth_module", func(t *testing.T) {
// Restore after destructive operation
sqlStore = db.InitTestDB(t)
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
require.NoError(t, err)
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
require.NoError(t, err)
for i := 0; i < 5; i++ {
cmd := user.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
}
_, err = usrSvc.Create(context.Background(), &cmd)
require.NoError(t, err)
}
// Find a user to set tokens on
userlogin := "loginuser0"
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
// Make the first log-in during the past
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &login.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err := srv.LookupAndUpdate(context.Background(), query)
database.GetTime = time.Now
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
// Add a second auth module for this user
// Have this module's last log-in be more recent
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
query = &login.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err = srv.LookupAndUpdate(context.Background(), query)
database.GetTime = time.Now
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
authInfoStore.ExpectedOAuth.AuthModule = "test2"
// Get the latest entry by not supply an authmodule or authid
getAuthQuery := &login.GetAuthInfoQuery{
UserId: user.ID,
}
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, authInfo.AuthModule, "test2")
// "log in" again with the first auth module
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1"}
err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd)
require.Nil(t, err)
authInfoStore.ExpectedOAuth.AuthModule = "test1"
// Get the latest entry by not supply an authmodule or authid
getAuthQuery = &login.GetAuthInfoQuery{
UserId: user.ID,
}
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, authInfo.AuthModule, "test1")
})
t.Run("Keeps track of last used auth_module when not using oauth", func(t *testing.T) {
// Restore after destructive operation
sqlStore = db.InitTestDB(t)
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
require.NoError(t, err)
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
require.NoError(t, err)
for i := 0; i < 5; i++ {
cmd := user.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
}
_, err := usrSvc.Create(context.Background(), &cmd)
require.Nil(t, err)
}
// Find a user to set tokens on
userlogin := "loginuser0"
fixedTime := time.Now()
// Calling srv.LookupAndUpdateQuery on an existing user will populate an entry in the user_auth table
// Make the first log-in during the past
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -2) }
queryOne := &login.GetUserByAuthInfoQuery{AuthModule: "test1", AuthId: "test1", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err := srv.LookupAndUpdate(context.Background(), queryOne)
database.GetTime = time.Now
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
// Add a second auth module for this user
// Have this module's last log-in be more recent
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, -1) }
queryTwo := &login.GetUserByAuthInfoQuery{AuthModule: "test2", AuthId: "test2", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err = srv.LookupAndUpdate(context.Background(), queryTwo)
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery := &login.GetAuthInfoQuery{
UserId: user.ID,
}
authInfoStore.ExpectedOAuth.AuthModule = "test2"
authInfo, err := authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test2", authInfo.AuthModule)
// Now reuse first auth module and make sure it's updated to the most recent
database.GetTime = func() time.Time { return fixedTime }
// add oauth info to auth_info to make sure update date does not overwrite it
updateAuthCmd := &login.UpdateAuthInfoCommand{UserId: user.ID, AuthModule: "test1", AuthId: "test1", OAuthToken: &oauth2.Token{
AccessToken: "access_token",
TokenType: "token_type",
RefreshToken: "refresh_token",
Expiry: fixedTime,
}}
err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd)
require.Nil(t, err)
user, err = srv.LookupAndUpdate(context.Background(), queryOne)
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
authInfoStore.ExpectedOAuth.AuthModule = "test1"
authInfoStore.ExpectedOAuth.OAuthAccessToken = "access_token"
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test1", authInfo.AuthModule)
// make sure oauth info is not overwritten by update date
require.Equal(t, "access_token", authInfo.OAuthAccessToken)
// Now reuse second auth module and make sure it's updated to the most recent
database.GetTime = func() time.Time { return fixedTime.AddDate(0, 0, 1) }
user, err = srv.LookupAndUpdate(context.Background(), queryTwo)
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
authInfoStore.ExpectedOAuth.AuthModule = "test2"
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err)
require.Equal(t, "test2", authInfo.AuthModule)
// Ensure test 1 did not have its entry modified
getAuthQueryUnchanged := &login.GetAuthInfoQuery{
UserId: user.ID,
AuthModule: "test1",
}
authInfoStore.ExpectedOAuth.AuthModule = "test1"
authInfo, err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
require.Nil(t, err)
require.Equal(t, "test1", authInfo.AuthModule)
})
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
// Find a user to set tokens on
userlogin := "loginuser0"
// Expect to pass since there's a matching login user
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &login.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: login.UserLookupParams{
Login: &userlogin,
}}
user, err := srv.LookupAndUpdate(context.Background(), query)
database.GetTime = time.Now
require.Nil(t, err)
require.Equal(t, user.Login, userlogin)
otherLoginUser := "aloginuser"
// Should throw a "user not found" error since there's no matching login user
database.GetTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query = &login.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: login.UserLookupParams{
Login: &otherLoginUser,
}}
authInfoStore.ExpectedError = errors.New("some error")
user, err = srv.LookupAndUpdate(context.Background(), query)
database.GetTime = time.Now
require.NotNil(t, err)
require.Nil(t, user)
authInfoStore.ExpectedError = nil
})
t.Run("should be able to run loginstats query in all dbs", func(t *testing.T) {
// we need to see that we can run queries for all db
// as it is only a concern for postgres/sqllite3
// where we have duplicate users
// Restore after destructive operation
sqlStore = db.InitTestDB(t)
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
require.NoError(t, err)
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
require.NoError(t, err)
for i := 0; i < 5; i++ {
cmd := user.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
OrgID: 1,
}
_, err := usrSvc.Create(context.Background(), &cmd)
require.Nil(t, err)
}
_, err = srv.authInfoStore.GetLoginStats(context.Background())
require.Nil(t, err)
})
t.Run("calculate metrics on duplicate userstats", func(t *testing.T) {
// Restore after destructive operation
sqlStore = db.InitTestDB(t)
qs := quotaimpl.ProvideService(sqlStore, sqlStore.Cfg)
orgSvc, err := orgimpl.ProvideService(sqlStore, sqlStore.Cfg, qs)
require.NoError(t, err)
usrSvc, err := userimpl.ProvideService(sqlStore, orgSvc, sqlStore.Cfg, nil, nil, qs, supportbundlestest.NewFakeBundleService())
require.NoError(t, err)
for i := 0; i < 5; i++ {
cmd := user.CreateUserCommand{
Email: fmt.Sprint("user", i, "@test.com"),
Name: fmt.Sprint("user", i),
Login: fmt.Sprint("loginuser", i),
OrgID: 1,
}
_, err := usrSvc.Create(context.Background(), &cmd)
require.Nil(t, err)
}
// "Skipping duplicate users test for mysql as it does make unique constraint case insensitive by default
if sqlStore.GetDialect().DriverName() != "mysql" {
dupUserEmailcmd := user.CreateUserCommand{
Email: "USERDUPLICATETEST1@TEST.COM",
Name: "user name 1",
Login: "USER_DUPLICATE_TEST_1_LOGIN",
}
_, err := usrSvc.Create(context.Background(), &dupUserEmailcmd)
require.NoError(t, err)
// add additional user with duplicate login where DOMAIN is upper case
dupUserLogincmd := user.CreateUserCommand{
Email: "userduplicatetest1@test.com",
Name: "user name 1",
Login: "user_duplicate_test_1_login",
}
_, err = usrSvc.Create(context.Background(), &dupUserLogincmd)
require.NoError(t, err)
authInfoStore.ExpectedUser = &user.User{
Email: "userduplicatetest1@test.com",
Name: "user name 1",
Login: "user_duplicate_test_1_login",
}
authInfoStore.ExpectedDuplicateUserEntries = 2
authInfoStore.ExpectedHasDuplicateUserEntries = 1
authInfoStore.ExpectedLoginStats = login.LoginStats{
DuplicateUserEntries: 2,
MixedCasedUsers: 1,
}
// require metrics and statistics to be 2
m, err := srv.authInfoStore.CollectLoginStats(context.Background())
require.NoError(t, err)
require.Equal(t, 2, m["stats.users.duplicate_user_entries"])
require.Equal(t, 1, m["stats.users.has_duplicate_user_entries"])
require.Equal(t, 1, m["stats.users.mixed_cased_users"])
}
})
})
}
type FakeAuthInfoStore struct {
login.AuthInfoService
ExpectedError error
ExpectedUser *user.User
ExpectedOAuth *login.UserAuth
ExpectedDuplicateUserEntries int
ExpectedHasDuplicateUserEntries int
ExpectedLoginStats login.LoginStats
}
func newFakeAuthInfoStore() *FakeAuthInfoStore {
return &FakeAuthInfoStore{}
}
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
return f.ExpectedOAuth, f.ExpectedError
}
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]any, error) {
var res = make(map[string]any)
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
res["stats.users.duplicate_user_entries_by_login"] = 0
res["stats.users.has_duplicate_user_entries_by_login"] = 0
res["stats.users.duplicate_user_entries_by_email"] = 0
res["stats.users.has_duplicate_user_entries_by_email"] = 0
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
return res, f.ExpectedError
}
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
return f.ExpectedLoginStats, f.ExpectedError
}

View File

@ -4,14 +4,12 @@ import (
"context"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
)
type AuthInfoServiceFake struct {
login.AuthInfoService
LatestUserID int64
ExpectedUserAuth *login.UserAuth
ExpectedUser *user.User
ExpectedExternalUser *login.ExternalUserInfo
ExpectedError error
ExpectedLabels map[int64]string
@ -20,15 +18,6 @@ type AuthInfoServiceFake struct {
UpdateAuthInfoFn func(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error
}
func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *login.GetUserByAuthInfoQuery) (*user.User, error) {
if query.UserLookupParams.UserID != nil {
a.LatestUserID = *query.UserLookupParams.UserID
} else {
a.LatestUserID = 0
}
return a.ExpectedUser, a.ExpectedError
}
func (a *AuthInfoServiceFake) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
a.LatestUserID = query.UserId
return a.ExpectedUserAuth, a.ExpectedError

View File

@ -13,7 +13,6 @@ import (
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/login/socialtest"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
@ -230,7 +229,7 @@ func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *social
}
authInfoStore := &FakeAuthInfoStore{}
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
authInfoService := authinfoservice.ProvideAuthInfoService(authInfoStore)
return &Service{
Cfg: setting.NewCfg(),
SocialService: socialService,