From 77e4d477e579ce1204e5a2967e1051f880398004 Mon Sep 17 00:00:00 2001 From: Jo Date: Mon, 11 Sep 2023 10:24:57 +0200 Subject: [PATCH] Auth: Optimize auth token operations (#74602) * add token count * wip * user count method for tag reporting * remove non functioning mysql clientFoundRows check * Update pkg/services/auth/authtest/testing.go Co-authored-by: Misi * add user ID guard --------- Co-authored-by: Misi --- pkg/services/auth/auth.go | 9 +++-- pkg/services/auth/authimpl/auth_token.go | 38 ++++++++++++------ pkg/services/auth/authimpl/auth_token_test.go | 39 +++++++++++++++++-- pkg/services/auth/authtest/testing.go | 18 ++++----- 4 files changed, 77 insertions(+), 27 deletions(-) diff --git a/pkg/services/auth/auth.go b/pkg/services/auth/auth.go index 9c163f63c72..4577f19994e 100644 --- a/pkg/services/auth/auth.go +++ b/pkg/services/auth/auth.go @@ -73,10 +73,11 @@ type UserTokenService interface { RotateToken(ctx context.Context, cmd RotateCommand) (*UserToken, error) TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error) RevokeToken(ctx context.Context, token *UserToken, soft bool) error - RevokeAllUserTokens(ctx context.Context, userId int64) error - GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error) - GetUserTokens(ctx context.Context, userId int64) ([]*UserToken, error) - GetUserRevokedTokens(ctx context.Context, userId int64) ([]*UserToken, error) + RevokeAllUserTokens(ctx context.Context, userID int64) error + GetUserToken(ctx context.Context, userID, userTokenID int64) (*UserToken, error) + GetUserTokens(ctx context.Context, userID int64) ([]*UserToken, error) + ActiveTokenCount(ctx context.Context, userID *int64) (int64, error) + GetUserRevokedTokens(ctx context.Context, userID int64) ([]*UserToken, error) } type UserTokenBackgroundService interface { diff --git a/pkg/services/auth/authimpl/auth_token.go b/pkg/services/auth/authimpl/auth_token.go index a2ee6f2cfba..b2ae8efc32d 100644 --- a/pkg/services/auth/authimpl/auth_token.go +++ b/pkg/services/auth/authimpl/auth_token.go @@ -26,6 +26,7 @@ import ( var ( getTime = time.Now errTokenNotRotated = errors.New("token was not rotated") + errUserIDInvalid = errors.New("invalid user ID") ) func ProvideUserAuthTokenService(sqlStore db.DB, @@ -529,6 +530,27 @@ func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) return result, err } +// ActiveTokenCount returns the number of active tokens. If userID is nil, the count is for all users. +func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context, userID *int64) (int64, error) { + if userID != nil && *userID < 1 { + return 0, errUserIDInvalid + } + + var count int64 + err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error { + query := `SELECT COUNT(*) FROM user_auth_token WHERE created_at > ? AND rotated_at > ? AND revoked_at = 0` + args := []interface{}{s.createdAfterParam(), s.rotatedAfterParam()} + if userID != nil { + query += " AND user_id = ?" + args = append(args, *userID) + } + _, err := dbSession.SQL(query, args...).Get(&count) + return err + }) + + return count, err +} + func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) { result := []*auth.UserToken{} err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error { @@ -553,22 +575,16 @@ func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId } func (s *UserAuthTokenService) reportActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) { - var count int64 - var err error - err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error { - var model userAuthToken - count, err = dbSession.Where(`created_at > ? AND rotated_at > ? AND revoked_at = 0`, - getTime().Add(-s.cfg.LoginMaxLifetime).Unix(), - getTime().Add(-s.cfg.LoginMaxInactiveLifetime).Unix()). - Count(&model) - - return err - }) + count, err := s.ActiveTokenCount(ctx, nil) + if err != nil { + return nil, err + } tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope) if err != nil { return nil, err } + u := "a.Map{} u.Set(tag, count) diff --git a/pkg/services/auth/authimpl/auth_token_test.go b/pkg/services/auth/authimpl/auth_token_test.go index 0393d22f134..3568ae29816 100644 --- a/pkg/services/auth/authimpl/auth_token_test.go +++ b/pkg/services/auth/authimpl/auth_token_test.go @@ -21,7 +21,7 @@ import ( "github.com/grafana/grafana/pkg/setting" ) -func TestUserAuthToken(t *testing.T) { +func TestIntegrationUserAuthToken(t *testing.T) { ctx := createTestContext(t) user := &user.User{ID: int64(10)} // userID := user.Id @@ -240,9 +240,8 @@ func TestUserAuthToken(t *testing.T) { }) t.Run("when rotated_at is 5 days ago and created_at is 30 days ago should return token expired error", func(t *testing.T) { - updated, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix()) + _, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix()) require.Nil(t, err) - require.True(t, updated) getTime = func() time.Time { return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour) @@ -674,3 +673,37 @@ func (c *testContext) updateRotatedAt(id, rotatedAt int64) (bool, error) { }) return hasRowsAffected, err } + +func TestIntegrationTokenCount(t *testing.T) { + ctx := createTestContext(t) + user := &user.User{ID: int64(10)} + + createToken := func() *auth.UserToken { + userToken, err := ctx.tokenService.CreateToken(context.Background(), user, + net.ParseIP("192.168.10.11"), "some user agent") + require.Nil(t, err) + require.NotNil(t, userToken) + require.False(t, userToken.AuthTokenSeen) + return userToken + } + + createToken() + + now := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC) + getTime = func() time.Time { return now } + defer func() { getTime = time.Now }() + + count, err := ctx.tokenService.ActiveTokenCount(context.Background(), nil) + require.Nil(t, err) + require.Equal(t, int64(1), count) + + var userID int64 = 10 + count, err = ctx.tokenService.ActiveTokenCount(context.Background(), &userID) + require.Nil(t, err) + require.Equal(t, int64(1), count) + + userID = 11 + count, err = ctx.tokenService.ActiveTokenCount(context.Background(), &userID) + require.Nil(t, err) + require.Equal(t, int64(0), count) +} diff --git a/pkg/services/auth/authtest/testing.go b/pkg/services/auth/authtest/testing.go index dd154d625e9..75682224aa6 100644 --- a/pkg/services/auth/authtest/testing.go +++ b/pkg/services/auth/authtest/testing.go @@ -20,12 +20,12 @@ type FakeUserAuthTokenService struct { TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error - RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error - ActiveAuthTokenCount func(ctx context.Context) (int64, error) - GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) - GetUserTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error) - GetUserRevokedTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error) - BatchRevokedTokenProvider func(ctx context.Context, userIds []int64) error + RevokeAllUserTokensProvider func(ctx context.Context, userID int64) error + ActiveTokenCountProvider func(ctx context.Context, userID *int64) (int64, error) + GetUserTokenProvider func(ctx context.Context, userID, userTokenID int64) (*auth.UserToken, error) + GetUserTokensProvider func(ctx context.Context, userID int64) ([]*auth.UserToken, error) + GetUserRevokedTokensProvider func(ctx context.Context, userID int64) ([]*auth.UserToken, error) + BatchRevokedTokenProvider func(ctx context.Context, userIDs []int64) error } func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { @@ -54,7 +54,7 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { BatchRevokedTokenProvider: func(ctx context.Context, userIds []int64) error { return nil }, - ActiveAuthTokenCount: func(ctx context.Context) (int64, error) { + ActiveTokenCountProvider: func(ctx context.Context, userID *int64) (int64, error) { return 10, nil }, GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { @@ -97,8 +97,8 @@ func (s *FakeUserAuthTokenService) RevokeAllUserTokens(ctx context.Context, user return s.RevokeAllUserTokensProvider(context.Background(), userId) } -func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) { - return s.ActiveAuthTokenCount(context.Background()) +func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context, userID *int64) (int64, error) { + return s.ActiveTokenCountProvider(context.Background(), userID) } func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) {