mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Rotate token patch (#62676)
* Use singleflight.Group * Align tests * Cleanup
This commit is contained in:
@@ -5,10 +5,13 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
@@ -41,6 +44,7 @@ func ProvideUserAuthTokenService(sqlStore db.DB,
|
||||
log: log.New("auth"),
|
||||
remoteCache: remoteCache,
|
||||
features: features,
|
||||
singleflight: new(singleflight.Group),
|
||||
}
|
||||
|
||||
defaultLimits, err := readQuotaConfig(cfg)
|
||||
@@ -68,6 +72,7 @@ type UserAuthTokenService struct {
|
||||
log log.Logger
|
||||
remoteCache *remotecache.RemoteCache
|
||||
features *featuremgmt.FeatureManager
|
||||
singleflight *singleflight.Group
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) {
|
||||
@@ -202,6 +207,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
}
|
||||
|
||||
// Current incoming token is the previous auth token in the DB and the auth_token_seen is true
|
||||
if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen {
|
||||
modelCopy := model
|
||||
modelCopy.AuthTokenSeen = false
|
||||
@@ -229,6 +235,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
}
|
||||
|
||||
// Current incoming token is not seen and it is the latest valid auth token in the db
|
||||
if !model.AuthTokenSeen && model.AuthToken == hashedToken {
|
||||
modelCopy := model
|
||||
modelCopy.AuthTokenSeen = true
|
||||
@@ -268,83 +275,102 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken,
|
||||
clientIP net.IP, userAgent string) (bool, error) {
|
||||
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
if token == nil {
|
||||
return false, nil
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
model, err := userAuthTokenFromUserToken(token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
now := getTime()
|
||||
|
||||
var needsRotation bool
|
||||
rotatedAt := time.Unix(model.RotatedAt, 0)
|
||||
if model.AuthTokenSeen {
|
||||
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute))
|
||||
} else {
|
||||
needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime))
|
||||
type rotationResult struct {
|
||||
rotated bool
|
||||
newToken *auth.UserToken
|
||||
}
|
||||
|
||||
if !needsRotation {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
ctxLogger := s.log.FromContext(ctx)
|
||||
ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt)
|
||||
|
||||
clientIPStr := clientIP.String()
|
||||
if len(clientIP) == 0 {
|
||||
clientIPStr = ""
|
||||
}
|
||||
newToken, err := util.RandomHex(16)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
hashedToken := hashToken(newToken)
|
||||
|
||||
// very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly
|
||||
sql := `
|
||||
UPDATE user_auth_token
|
||||
SET
|
||||
seen_at = 0,
|
||||
user_agent = ?,
|
||||
client_ip = ?,
|
||||
prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end,
|
||||
auth_token = ?,
|
||||
auth_token_seen = ?,
|
||||
rotated_at = ?
|
||||
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
|
||||
|
||||
var affected int64
|
||||
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
|
||||
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken,
|
||||
s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true),
|
||||
now.Add(-30*time.Second).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
rotResult, err, _ := s.singleflight.Do(fmt.Sprint(model.Id), func() (interface{}, error) {
|
||||
var needsRotation bool
|
||||
rotatedAt := time.Unix(model.RotatedAt, 0)
|
||||
if model.AuthTokenSeen {
|
||||
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute))
|
||||
} else {
|
||||
needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime))
|
||||
}
|
||||
|
||||
affected, err = res.RowsAffected()
|
||||
return err
|
||||
if !needsRotation {
|
||||
return &rotationResult{rotated: false}, nil
|
||||
}
|
||||
|
||||
ctxLogger := s.log.FromContext(ctx)
|
||||
ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt)
|
||||
|
||||
clientIPStr := clientIP.String()
|
||||
if len(clientIP) == 0 {
|
||||
clientIPStr = ""
|
||||
}
|
||||
newToken, err := util.RandomHex(16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashedToken := hashToken(newToken)
|
||||
|
||||
// very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly
|
||||
sql := `
|
||||
UPDATE user_auth_token
|
||||
SET
|
||||
seen_at = 0,
|
||||
user_agent = ?,
|
||||
client_ip = ?,
|
||||
prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end,
|
||||
auth_token = ?,
|
||||
auth_token_seen = ?,
|
||||
rotated_at = ?
|
||||
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
|
||||
|
||||
var affected int64
|
||||
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
|
||||
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken,
|
||||
s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true),
|
||||
now.Add(-30*time.Second).Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err = res.RowsAffected()
|
||||
return err
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if affected > 0 {
|
||||
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
|
||||
model.UnhashedToken = newToken
|
||||
var result auth.UserToken
|
||||
if err := model.toUserToken(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rotationResult{
|
||||
rotated: true,
|
||||
newToken: &result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &rotationResult{rotated: false}, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
|
||||
if affected > 0 {
|
||||
model.UnhashedToken = newToken
|
||||
if err := model.toUserToken(token); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
result := rotResult.(*rotationResult)
|
||||
|
||||
return false, nil
|
||||
return result.rotated, result.newToken, nil
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
@@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
|
||||
getTime = func() time.Time { return now.Add(time.Hour) }
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.11"), "some user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -262,7 +263,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
prevToken := userToken.AuthToken
|
||||
unhashedPrev := userToken.UnhashedToken
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.False(t, rotated)
|
||||
@@ -280,12 +281,12 @@ func TestUserAuthToken(t *testing.T) {
|
||||
|
||||
getTime = func() time.Time { return now.Add(time.Hour) }
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), &tok,
|
||||
rotated, newToken, err := ctx.tokenService.TryRotateToken(context.Background(), &tok,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
|
||||
unhashedToken := tok.UnhashedToken
|
||||
unhashedToken := newToken.UnhashedToken
|
||||
|
||||
model, err = ctx.getAuthTokenByID(tok.Id)
|
||||
require.Nil(t, err)
|
||||
@@ -326,7 +327,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
require.NotNil(t, lookedUpModel)
|
||||
require.False(t, lookedUpModel.AuthTokenSeen)
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -351,7 +352,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
getTime = func() time.Time { return now.Add(10 * time.Minute) }
|
||||
|
||||
prevToken := userToken.UnhashedToken
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -407,7 +408,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(10 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -429,7 +430,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(20 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -456,7 +457,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(2 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@@ -550,9 +551,10 @@ func createTestContext(t *testing.T) *testContext {
|
||||
}
|
||||
|
||||
tokenService := &UserAuthTokenService{
|
||||
sqlStore: sqlstore,
|
||||
cfg: cfg,
|
||||
log: log.New("test-logger"),
|
||||
sqlStore: sqlstore,
|
||||
cfg: cfg,
|
||||
log: log.New("test-logger"),
|
||||
singleflight: new(singleflight.Group),
|
||||
}
|
||||
|
||||
return &testContext{
|
||||
|
||||
Reference in New Issue
Block a user