Auth: Rotate token patch (#62676)

* Use singleflight.Group

* Align tests

* Cleanup
This commit is contained in:
Misi 2023-02-02 14:36:16 +01:00 committed by GitHub
parent 3c01ae2c9e
commit 7c1d9769ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 122 additions and 88 deletions

View File

@ -13,10 +13,11 @@ import (
"testing"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/infra/db/dbtest"
"github.com/grafana/grafana/pkg/infra/fs"
@ -344,9 +345,9 @@ func TestMiddlewareContext(t *testing.T) {
}
sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken,
clientIP net.IP, userAgent string) (bool, error) {
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
userToken.UnhashedToken = "rotated"
return true, nil
return true, userToken, nil
}
maxAge := int(sc.cfg.LoginMaxLifetime.Seconds())

View File

@ -62,7 +62,7 @@ type RevokeAuthTokenCmd struct {
type UserTokenService interface {
CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*UserToken, error)
LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error)
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, error)
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error)
RevokeToken(ctx context.Context, token *UserToken, soft bool) error
RevokeAllUserTokens(ctx context.Context, userId int64) error
GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error)

View File

@ -5,10 +5,13 @@ import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net"
"strings"
"time"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
@ -41,6 +44,7 @@ func ProvideUserAuthTokenService(sqlStore db.DB,
log: log.New("auth"),
remoteCache: remoteCache,
features: features,
singleflight: new(singleflight.Group),
}
defaultLimits, err := readQuotaConfig(cfg)
@ -68,6 +72,7 @@ type UserAuthTokenService struct {
log log.Logger
remoteCache *remotecache.RemoteCache
features *featuremgmt.FeatureManager
singleflight *singleflight.Group
}
func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) {
@ -202,6 +207,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
}
}
// Current incoming token is the previous auth token in the DB and the auth_token_seen is true
if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen {
modelCopy := model
modelCopy.AuthTokenSeen = false
@ -229,6 +235,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
}
}
// Current incoming token is not seen and it is the latest valid auth token in the db
if !model.AuthTokenSeen && model.AuthToken == hashedToken {
modelCopy := model
modelCopy.AuthTokenSeen = true
@ -268,83 +275,102 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
}
func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken,
clientIP net.IP, userAgent string) (bool, error) {
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
if token == nil {
return false, nil
return false, nil, nil
}
model, err := userAuthTokenFromUserToken(token)
if err != nil {
return false, err
return false, nil, err
}
now := getTime()
var needsRotation bool
rotatedAt := time.Unix(model.RotatedAt, 0)
if model.AuthTokenSeen {
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute))
} else {
needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime))
type rotationResult struct {
rotated bool
newToken *auth.UserToken
}
if !needsRotation {
return false, nil
}
ctxLogger := s.log.FromContext(ctx)
ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt)
clientIPStr := clientIP.String()
if len(clientIP) == 0 {
clientIPStr = ""
}
newToken, err := util.RandomHex(16)
if err != nil {
return false, err
}
hashedToken := hashToken(newToken)
// very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly
sql := `
UPDATE user_auth_token
SET
seen_at = 0,
user_agent = ?,
client_ip = ?,
prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end,
auth_token = ?,
auth_token_seen = ?,
rotated_at = ?
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
var affected int64
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken,
s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true),
now.Add(-30*time.Second).Unix())
if err != nil {
return err
rotResult, err, _ := s.singleflight.Do(fmt.Sprint(model.Id), func() (interface{}, error) {
var needsRotation bool
rotatedAt := time.Unix(model.RotatedAt, 0)
if model.AuthTokenSeen {
needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute))
} else {
needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime))
}
affected, err = res.RowsAffected()
return err
if !needsRotation {
return &rotationResult{rotated: false}, nil
}
ctxLogger := s.log.FromContext(ctx)
ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt)
clientIPStr := clientIP.String()
if len(clientIP) == 0 {
clientIPStr = ""
}
newToken, err := util.RandomHex(16)
if err != nil {
return nil, err
}
hashedToken := hashToken(newToken)
// very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly
sql := `
UPDATE user_auth_token
SET
seen_at = 0,
user_agent = ?,
client_ip = ?,
prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end,
auth_token = ?,
auth_token_seen = ?,
rotated_at = ?
WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)`
var affected int64
err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error {
res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken,
s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true),
now.Add(-30*time.Second).Unix())
if err != nil {
return err
}
affected, err = res.RowsAffected()
return err
})
if err != nil {
return nil, err
}
if affected > 0 {
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
model.UnhashedToken = newToken
var result auth.UserToken
if err := model.toUserToken(&result); err != nil {
return nil, err
}
return &rotationResult{
rotated: true,
newToken: &result,
}, nil
}
return &rotationResult{rotated: false}, nil
})
if err != nil {
return false, err
return false, nil, err
}
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
if affected > 0 {
model.UnhashedToken = newToken
if err := model.toUserToken(token); err != nil {
return false, err
}
return true, nil
}
result := rotResult.(*rotationResult)
return false, nil
return result.rotated, result.newToken, nil
}
func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error {

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/db"
@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) {
getTime = func() time.Time { return now.Add(time.Hour) }
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("192.168.10.11"), "some user agent")
require.Nil(t, err)
require.True(t, rotated)
@ -262,7 +263,7 @@ func TestUserAuthToken(t *testing.T) {
prevToken := userToken.AuthToken
unhashedPrev := userToken.UnhashedToken
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("192.168.10.12"), "a new user agent")
require.Nil(t, err)
require.False(t, rotated)
@ -280,12 +281,12 @@ func TestUserAuthToken(t *testing.T) {
getTime = func() time.Time { return now.Add(time.Hour) }
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), &tok,
rotated, newToken, err := ctx.tokenService.TryRotateToken(context.Background(), &tok,
net.ParseIP("192.168.10.12"), "a new user agent")
require.Nil(t, err)
require.True(t, rotated)
unhashedToken := tok.UnhashedToken
unhashedToken := newToken.UnhashedToken
model, err = ctx.getAuthTokenByID(tok.Id)
require.Nil(t, err)
@ -326,7 +327,7 @@ func TestUserAuthToken(t *testing.T) {
require.NotNil(t, lookedUpModel)
require.False(t, lookedUpModel.AuthTokenSeen)
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("192.168.10.12"), "a new user agent")
require.Nil(t, err)
require.True(t, rotated)
@ -351,7 +352,7 @@ func TestUserAuthToken(t *testing.T) {
getTime = func() time.Time { return now.Add(10 * time.Minute) }
prevToken := userToken.UnhashedToken
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("1.1.1.1"), "firefox")
require.Nil(t, err)
require.True(t, rotated)
@ -407,7 +408,7 @@ func TestUserAuthToken(t *testing.T) {
return now.Add(10 * time.Minute)
}
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("1.1.1.1"), "firefox")
require.Nil(t, err)
require.True(t, rotated)
@ -429,7 +430,7 @@ func TestUserAuthToken(t *testing.T) {
return now.Add(20 * time.Minute)
}
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("1.1.1.1"), "firefox")
require.Nil(t, err)
require.True(t, rotated)
@ -456,7 +457,7 @@ func TestUserAuthToken(t *testing.T) {
return now.Add(2 * time.Minute)
}
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
net.ParseIP("1.1.1.1"), "firefox")
require.Nil(t, err)
require.True(t, rotated)
@ -550,9 +551,10 @@ func createTestContext(t *testing.T) *testContext {
}
tokenService := &UserAuthTokenService{
sqlStore: sqlstore,
cfg: cfg,
log: log.New("test-logger"),
sqlStore: sqlstore,
cfg: cfg,
log: log.New("test-logger"),
singleflight: new(singleflight.Group),
}
return &testContext{

View File

@ -15,7 +15,7 @@ import (
type FakeUserAuthTokenService struct {
CreateTokenProvider func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error)
TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error)
TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error)
LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error)
RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error
RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error
@ -34,8 +34,8 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
UnhashedToken: "",
}, nil
},
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) {
return false, nil
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
return false, nil, nil
},
LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
@ -79,7 +79,7 @@ func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToke
}
func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
userAgent string) (bool, *auth.UserToken, error) {
return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent)
}

View File

@ -107,13 +107,14 @@ func (s *Session) RefreshTokenHook(ctx context.Context, identity *authn.Identity
s.log.Debug("failed to get client IP address", "addr", addr, "err", err)
ip = nil
}
rotated, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent)
rotated, newToken, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent)
if err != nil {
s.log.Error("failed to rotate token", "error", err)
return
}
if rotated {
identity.SessionToken = newToken
s.log.Debug("rotated session token", "user", identity.ID)
maxAge := int(s.loginMaxLifetime.Seconds())

View File

@ -143,9 +143,9 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) {
func TestSession_RefreshHook(t *testing.T) {
s := ProvideSession(&authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) {
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
token.UnhashedToken = "new-token"
return true, nil
return true, token, nil
},
}, &usertest.FakeUserService{}, "grafana-session", 20*time.Second)

View File

@ -11,6 +11,8 @@ import (
"strings"
"time"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/components/apikeygen"
apikeygenprefix "github.com/grafana/grafana/pkg/components/apikeygenprefixed"
"github.com/grafana/grafana/pkg/infra/db"
@ -70,6 +72,7 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer
oauthTokenService: oauthTokenService,
features: features,
authnService: authnService,
singleflight: new(singleflight.Group),
}
}
@ -91,6 +94,7 @@ type ContextHandler struct {
oauthTokenService oauthtoken.OAuthTokenService
features *featuremgmt.FeatureManager
authnService authn.Service
singleflight *singleflight.Group
// GetTime returns the current time.
// Stubbable by tests.
GetTime func() time.Time
@ -568,15 +572,15 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *contextmodel.ReqCont
ip = nil
}
// FIXME (jguer): rotation should return a new token instead of modifying the existing one.
rotated, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
rotated, newToken, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
if err != nil {
reqContext.Logger.Error("Failed to rotate token", "error", err)
return
}
if rotated {
cookies.WriteSessionCookie(reqContext, h.Cfg, reqContext.UserToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
reqContext.UserToken = newToken
cookies.WriteSessionCookie(reqContext, h.Cfg, newToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
}
}
}

View File

@ -24,9 +24,9 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
tryRotateCallCount := 0
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
userAgent string) (bool, *auth.UserToken, error) {
tryRotateCallCount++
return false, nil
return false, nil, nil
},
}
@ -46,11 +46,11 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
ctxHdlr := getContextHandler(t)
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
userAgent string) (bool, *auth.UserToken, error) {
newToken, err := util.RandomHex(16)
require.NoError(t, err)
token.AuthToken = newToken
return true, nil
return true, token, nil
},
}