Auth: Merge ActiveAuthTokenService into UserAuthTokenService (#59032)

* Auth: Merge UserTokenService and ActiveAuthTokenService

* Auth: Rename function
This commit is contained in:
Karl Persson
2022-11-22 10:58:59 +01:00
committed by GitHub
parent f8f61c1a69
commit 062c5b805c
12 changed files with 70 additions and 110 deletions

View File

@@ -16,10 +16,6 @@ const (
QuotaTarget quota.Target = "session"
)
type ActiveTokenService interface {
ActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error)
}
// Typed errors
var (
ErrUserTokenNotFound = errors.New("user token not found")

View File

@@ -18,40 +18,17 @@ import (
"github.com/grafana/grafana/pkg/util"
)
const ServiceName = "UserAuthTokenService"
const urgentRotateTime = 1 * time.Minute
var getTime = time.Now
const urgentRotateTime = 1 * time.Minute
func ProvideUserAuthTokenService(sqlStore db.DB, serverLockService *serverlock.ServerLockService,
cfg *setting.Cfg) *UserAuthTokenService {
func ProvideUserAuthTokenService(sqlStore db.DB, cfg *setting.Cfg, serverLockService *serverlock.ServerLockService, quotaService quota.Service) (*UserAuthTokenService, error) {
s := &UserAuthTokenService{
SQLStore: sqlStore,
ServerLockService: serverLockService,
Cfg: cfg,
sqlStore: sqlStore,
serverLockService: serverLockService,
cfg: cfg,
log: log.New("auth"),
}
return s
}
type UserAuthTokenService struct {
SQLStore db.DB
ServerLockService *serverlock.ServerLockService
Cfg *setting.Cfg
log log.Logger
}
type ActiveAuthTokenService struct {
cfg *setting.Cfg
sqlStore db.DB
}
func ProvideActiveAuthTokenService(cfg *setting.Cfg, sqlStore db.DB, quotaService quota.Service) (*ActiveAuthTokenService, error) {
s := &ActiveAuthTokenService{
cfg: cfg,
sqlStore: sqlStore,
}
defaultLimits, err := readQuotaConfig(cfg)
if err != nil {
@@ -61,7 +38,7 @@ func ProvideActiveAuthTokenService(cfg *setting.Cfg, sqlStore db.DB, quotaServic
if err := quotaService.RegisterQuotaReporter(&quota.NewUsageReporter{
TargetSrv: auth.QuotaTargetSrv,
DefaultLimits: defaultLimits,
Reporter: s.ActiveTokenCount,
Reporter: s.reportActiveTokenCount,
}); err != nil {
return s, err
}
@@ -69,27 +46,11 @@ func ProvideActiveAuthTokenService(cfg *setting.Cfg, sqlStore db.DB, quotaServic
return s, nil
}
func (a *ActiveAuthTokenService) ActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) {
var count int64
var err error
err = a.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var model userAuthToken
count, err = dbSession.Where(`created_at > ? AND rotated_at > ? AND revoked_at = 0`,
getTime().Add(-a.cfg.LoginMaxLifetime).Unix(),
getTime().Add(-a.cfg.LoginMaxInactiveLifetime).Unix()).
Count(&model)
return err
})
tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope)
if err != nil {
return nil, err
}
u := &quota.Map{}
u.Set(tag, count)
return u, err
type UserAuthTokenService struct {
sqlStore db.DB
serverLockService *serverlock.ServerLockService
cfg *setting.Cfg
log log.Logger
}
func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) {
@@ -120,7 +81,7 @@ func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User,
AuthTokenSeen: false,
}
err = s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
_, err = dbSession.Insert(&userAuthToken)
return err
})
@@ -145,7 +106,7 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st
var model userAuthToken
var exists bool
var err error
err = s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?)",
hashedToken,
hashedToken).
@@ -185,7 +146,7 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st
expireBefore := getTime().Add(-urgentRotateTime).Unix()
var affectedRows int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
affectedRows, err = dbSession.Where("id = ? AND prev_auth_token = ? AND rotated_at < ?",
modelCopy.Id,
modelCopy.PrevAuthToken,
@@ -212,7 +173,7 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st
modelCopy.SeenAt = getTime().Unix()
var affectedRows int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
affectedRows, err = dbSession.Where("id = ? AND auth_token = ?",
modelCopy.Id,
modelCopy.AuthToken).
@@ -260,7 +221,7 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.U
var needsRotation bool
rotatedAt := time.Unix(model.RotatedAt, 0)
if model.AuthTokenSeen {
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.Cfg.TokenRotationIntervalMinutes) * time.Minute))
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute))
} else {
needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime))
}
@@ -296,9 +257,9 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.U
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
var affected int64
err = s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.SQLStore.GetDialect().BooleanStr(true), hashedToken,
s.SQLStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.SQLStore.GetDialect().BooleanStr(true),
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken,
s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true),
now.Add(-30*time.Second).Unix())
if err != nil {
return err
@@ -338,12 +299,12 @@ func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.User
if soft {
model.RevokedAt = getTime().Unix()
err = s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
rowsAffected, err = dbSession.ID(model.Id).Update(model)
return err
})
} else {
err = s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
rowsAffected, err = dbSession.Delete(model)
return err
})
@@ -366,7 +327,7 @@ func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.User
}
func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId int64) error {
return s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
return s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
sql := `DELETE from user_auth_token WHERE user_id = ?`
res, err := dbSession.Exec(sql, userId)
if err != nil {
@@ -385,7 +346,7 @@ func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId i
}
func (s *UserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
return s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
return s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
if len(userIds) == 0 {
return nil
}
@@ -416,7 +377,7 @@ func (s *UserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, use
func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) {
var result auth.UserToken
err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var token userAuthToken
exists, err := dbSession.Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
if err != nil {
@@ -435,7 +396,7 @@ func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTok
func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) {
result := []*auth.UserToken{}
err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var tokens []*userAuthToken
err := dbSession.Where("user_id = ? AND created_at > ? AND rotated_at > ? AND revoked_at = 0",
userId,
@@ -462,7 +423,7 @@ func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64)
func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) {
result := []*auth.UserToken{}
err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var tokens []*userAuthToken
err := dbSession.Where("user_id = ? AND revoked_at > 0", userId).Find(&tokens)
if err != nil {
@@ -483,12 +444,35 @@ func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId
return result, err
}
func (s *UserAuthTokenService) reportActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) {
var count int64
var err error
err = s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
var model userAuthToken
count, err = dbSession.Where(`created_at > ? AND rotated_at > ? AND revoked_at = 0`,
getTime().Add(-s.cfg.LoginMaxLifetime).Unix(),
getTime().Add(-s.cfg.LoginMaxInactiveLifetime).Unix()).
Count(&model)
return err
})
tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope)
if err != nil {
return nil, err
}
u := &quota.Map{}
u.Set(tag, count)
return u, err
}
func (s *UserAuthTokenService) createdAfterParam() int64 {
return getTime().Add(-s.Cfg.LoginMaxLifetime).Unix()
return getTime().Add(-s.cfg.LoginMaxLifetime).Unix()
}
func (s *UserAuthTokenService) rotatedAfterParam() int64 {
return getTime().Add(-s.Cfg.LoginMaxInactiveLifetime).Unix()
return getTime().Add(-s.cfg.LoginMaxInactiveLifetime).Unix()
}
func hashToken(token string) string {

View File

@@ -41,7 +41,7 @@ func TestUserAuthToken(t *testing.T) {
userToken := createToken()
t.Run("Can count active tokens", func(t *testing.T) {
m, err := ctx.activeTokenService.ActiveTokenCount(context.Background(), &quota.ScopeParameters{})
m, err := ctx.tokenService.reportActiveTokenCount(context.Background(), &quota.ScopeParameters{})
require.Nil(t, err)
tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope)
require.NoError(t, err)
@@ -213,7 +213,7 @@ func TestUserAuthToken(t *testing.T) {
require.Nil(t, notGood)
t.Run("should not find active token when expired", func(t *testing.T) {
m, err := ctx.activeTokenService.ActiveTokenCount(context.Background(), &quota.ScopeParameters{})
m, err := ctx.tokenService.reportActiveTokenCount(context.Background(), &quota.ScopeParameters{})
require.Nil(t, err)
tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope)
require.NoError(t, err)
@@ -550,27 +550,20 @@ func createTestContext(t *testing.T) *testContext {
}
tokenService := &UserAuthTokenService{
SQLStore: sqlstore,
Cfg: cfg,
sqlStore: sqlstore,
cfg: cfg,
log: log.New("test-logger"),
}
activeTokenService := &ActiveAuthTokenService{
cfg: cfg,
sqlStore: sqlstore,
}
return &testContext{
sqlstore: sqlstore,
tokenService: tokenService,
activeTokenService: activeTokenService,
sqlstore: sqlstore,
tokenService: tokenService,
}
}
type testContext struct {
sqlstore db.DB
tokenService *UserAuthTokenService
activeTokenService *ActiveAuthTokenService
sqlstore db.DB
tokenService *UserAuthTokenService
}
func (c *testContext) getAuthTokenByID(id int64) (*userAuthToken, error) {

View File

@@ -9,10 +9,10 @@ import (
func (s *UserAuthTokenService) Run(ctx context.Context) error {
ticker := time.NewTicker(time.Hour)
maxInactiveLifetime := s.Cfg.LoginMaxInactiveLifetime
maxLifetime := s.Cfg.LoginMaxLifetime
maxInactiveLifetime := s.cfg.LoginMaxInactiveLifetime
maxLifetime := s.cfg.LoginMaxLifetime
err := s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
err := s.serverLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil {
s.log.Error("An error occurred while deleting expired tokens", "err", err)
}
@@ -24,7 +24,7 @@ func (s *UserAuthTokenService) Run(ctx context.Context) error {
for {
select {
case <-ticker.C:
err = s.ServerLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
err = s.serverLockService.LockAndExecute(ctx, "cleanup expired auth tokens", time.Hour*12, func(context.Context) {
if _, err := s.deleteExpiredTokens(ctx, maxInactiveLifetime, maxLifetime); err != nil {
s.log.Error("An error occurred while deleting expired tokens", "err", err)
}
@@ -46,7 +46,7 @@ func (s *UserAuthTokenService) deleteExpiredTokens(ctx context.Context, maxInact
s.log.Debug("starting cleanup of expired auth tokens", "createdBefore", createdBefore, "rotatedBefore", rotatedBefore)
var affected int64
err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error {
err := s.sqlStore.WithDbSession(ctx, func(dbSession *db.Session) error {
sql := `DELETE from user_auth_token WHERE created_at <= ? OR rotated_at <= ?`
res, err := dbSession.Exec(sql, createdBefore.Unix(), rotatedBefore.Unix())
if err != nil {

View File

@@ -16,8 +16,8 @@ func TestUserAuthTokenCleanup(t *testing.T) {
ctx := createTestContext(t)
maxInactiveLifetime, _ := time.ParseDuration("168h")
maxLifetime, _ := time.ParseDuration("720h")
ctx.tokenService.Cfg.LoginMaxInactiveLifetime = maxInactiveLifetime
ctx.tokenService.Cfg.LoginMaxLifetime = maxLifetime
ctx.tokenService.cfg.LoginMaxInactiveLifetime = maxInactiveLifetime
ctx.tokenService.cfg.LoginMaxLifetime = maxLifetime
return ctx
}