Auth: Optimize auth token operations (#74602)

* add token count

* wip

* user count method for tag reporting

* remove non functioning mysql clientFoundRows check

* Update pkg/services/auth/authtest/testing.go

Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>

* add user ID guard

---------

Co-authored-by: Misi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
Jo 2023-09-11 10:24:57 +02:00 committed by GitHub
parent 09137da78c
commit 77e4d477e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 27 deletions

View File

@ -73,10 +73,11 @@ type UserTokenService interface {
RotateToken(ctx context.Context, cmd RotateCommand) (*UserToken, error)
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error)
RevokeToken(ctx context.Context, token *UserToken, soft bool) error
RevokeAllUserTokens(ctx context.Context, userId int64) error
GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error)
GetUserTokens(ctx context.Context, userId int64) ([]*UserToken, error)
GetUserRevokedTokens(ctx context.Context, userId int64) ([]*UserToken, error)
RevokeAllUserTokens(ctx context.Context, userID int64) error
GetUserToken(ctx context.Context, userID, userTokenID int64) (*UserToken, error)
GetUserTokens(ctx context.Context, userID int64) ([]*UserToken, error)
ActiveTokenCount(ctx context.Context, userID *int64) (int64, error)
GetUserRevokedTokens(ctx context.Context, userID int64) ([]*UserToken, error)
}
type UserTokenBackgroundService interface {

View File

@ -26,6 +26,7 @@ import (
var (
getTime = time.Now
errTokenNotRotated = errors.New("token was not rotated")
errUserIDInvalid = errors.New("invalid user ID")
)
func ProvideUserAuthTokenService(sqlStore db.DB,
@ -529,6 +530,27 @@ func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64)
return result, err
}
// ActiveTokenCount returns the number of active tokens. If userID is nil, the count is for all users.
func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context, userID *int64) (int64, error) {
if userID != nil && *userID < 1 {
return 0, errUserIDInvalid
}
var count int64
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
query := `SELECT COUNT(*) FROM user_auth_token WHERE created_at > ? AND rotated_at > ? AND revoked_at = 0`
args := []interface{}{s.createdAfterParam(), s.rotatedAfterParam()}
if userID != nil {
query += " AND user_id = ?"
args = append(args, *userID)
}
_, err := dbSession.SQL(query, args...).Get(&count)
return err
})
return count, err
}
func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) {
result := []*auth.UserToken{}
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
@ -553,22 +575,16 @@ func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId
}
func (s *UserAuthTokenService) reportActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) {
var count int64
var err error
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var model userAuthToken
count, err = dbSession.Where(`created_at > ? AND rotated_at > ? AND revoked_at = 0`,
getTime().Add(-s.cfg.LoginMaxLifetime).Unix(),
getTime().Add(-s.cfg.LoginMaxInactiveLifetime).Unix()).
Count(&model)
return err
})
count, err := s.ActiveTokenCount(ctx, nil)
if err != nil {
return nil, err
}
tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope)
if err != nil {
return nil, err
}
u := &quota.Map{}
u.Set(tag, count)

View File

@ -21,7 +21,7 @@ import (
"github.com/grafana/grafana/pkg/setting"
)
func TestUserAuthToken(t *testing.T) {
func TestIntegrationUserAuthToken(t *testing.T) {
ctx := createTestContext(t)
user := &user.User{ID: int64(10)}
// userID := user.Id
@ -240,9 +240,8 @@ func TestUserAuthToken(t *testing.T) {
})
t.Run("when rotated_at is 5 days ago and created_at is 30 days ago should return token expired error", func(t *testing.T) {
updated, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix())
_, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix())
require.Nil(t, err)
require.True(t, updated)
getTime = func() time.Time {
return time.Unix(model.CreatedAt, 0).Add(24 * 30 * time.Hour)
@ -674,3 +673,37 @@ func (c *testContext) updateRotatedAt(id, rotatedAt int64) (bool, error) {
})
return hasRowsAffected, err
}
func TestIntegrationTokenCount(t *testing.T) {
ctx := createTestContext(t)
user := &user.User{ID: int64(10)}
createToken := func() *auth.UserToken {
userToken, err := ctx.tokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent")
require.Nil(t, err)
require.NotNil(t, userToken)
require.False(t, userToken.AuthTokenSeen)
return userToken
}
createToken()
now := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC)
getTime = func() time.Time { return now }
defer func() { getTime = time.Now }()
count, err := ctx.tokenService.ActiveTokenCount(context.Background(), nil)
require.Nil(t, err)
require.Equal(t, int64(1), count)
var userID int64 = 10
count, err = ctx.tokenService.ActiveTokenCount(context.Background(), &userID)
require.Nil(t, err)
require.Equal(t, int64(1), count)
userID = 11
count, err = ctx.tokenService.ActiveTokenCount(context.Background(), &userID)
require.Nil(t, err)
require.Equal(t, int64(0), count)
}

View File

@ -20,12 +20,12 @@ type FakeUserAuthTokenService struct {
TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error)
LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error)
RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error
RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error
ActiveAuthTokenCount func(ctx context.Context) (int64, error)
GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error)
GetUserTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error)
GetUserRevokedTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error)
BatchRevokedTokenProvider func(ctx context.Context, userIds []int64) error
RevokeAllUserTokensProvider func(ctx context.Context, userID int64) error
ActiveTokenCountProvider func(ctx context.Context, userID *int64) (int64, error)
GetUserTokenProvider func(ctx context.Context, userID, userTokenID int64) (*auth.UserToken, error)
GetUserTokensProvider func(ctx context.Context, userID int64) ([]*auth.UserToken, error)
GetUserRevokedTokensProvider func(ctx context.Context, userID int64) ([]*auth.UserToken, error)
BatchRevokedTokenProvider func(ctx context.Context, userIDs []int64) error
}
func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
@ -54,7 +54,7 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
BatchRevokedTokenProvider: func(ctx context.Context, userIds []int64) error {
return nil
},
ActiveAuthTokenCount: func(ctx context.Context) (int64, error) {
ActiveTokenCountProvider: func(ctx context.Context, userID *int64) (int64, error) {
return 10, nil
},
GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) {
@ -97,8 +97,8 @@ func (s *FakeUserAuthTokenService) RevokeAllUserTokens(ctx context.Context, user
return s.RevokeAllUserTokensProvider(context.Background(), userId)
}
func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, error) {
return s.ActiveAuthTokenCount(context.Background())
func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context, userID *int64) (int64, error) {
return s.ActiveTokenCountProvider(context.Background(), userID)
}
func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) {