mirror of
https://github.com/grafana/grafana.git
synced 2025-02-11 08:05:43 -06:00
Auth: Rotate token patch (#62676)
* Use singleflight.Group * Align tests * Cleanup
This commit is contained in:
parent
3c01ae2c9e
commit
7c1d9769ca
@ -13,10 +13,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/dtos"
|
||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
||||
"github.com/grafana/grafana/pkg/infra/fs"
|
||||
@ -344,9 +345,9 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
}
|
||||
|
||||
sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken,
|
||||
clientIP net.IP, userAgent string) (bool, error) {
|
||||
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
userToken.UnhashedToken = "rotated"
|
||||
return true, nil
|
||||
return true, userToken, nil
|
||||
}
|
||||
|
||||
maxAge := int(sc.cfg.LoginMaxLifetime.Seconds())
|
||||
|
@ -62,7 +62,7 @@ type RevokeAuthTokenCmd struct {
|
||||
type UserTokenService interface {
|
||||
CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*UserToken, error)
|
||||
LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error)
|
||||
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, error)
|
||||
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error)
|
||||
RevokeToken(ctx context.Context, token *UserToken, soft bool) error
|
||||
RevokeAllUserTokens(ctx context.Context, userId int64) error
|
||||
GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error)
|
||||
|
@ -5,10 +5,13 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
@ -41,6 +44,7 @@ func ProvideUserAuthTokenService(sqlStore db.DB,
|
||||
log: log.New("auth"),
|
||||
remoteCache: remoteCache,
|
||||
features: features,
|
||||
singleflight: new(singleflight.Group),
|
||||
}
|
||||
|
||||
defaultLimits, err := readQuotaConfig(cfg)
|
||||
@ -68,6 +72,7 @@ type UserAuthTokenService struct {
|
||||
log log.Logger
|
||||
remoteCache *remotecache.RemoteCache
|
||||
features *featuremgmt.FeatureManager
|
||||
singleflight *singleflight.Group
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) {
|
||||
@ -202,6 +207,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
}
|
||||
|
||||
// Current incoming token is the previous auth token in the DB and the auth_token_seen is true
|
||||
if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen {
|
||||
modelCopy := model
|
||||
modelCopy.AuthTokenSeen = false
|
||||
@ -229,6 +235,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
}
|
||||
|
||||
// Current incoming token is not seen and it is the latest valid auth token in the db
|
||||
if !model.AuthTokenSeen && model.AuthToken == hashedToken {
|
||||
modelCopy := model
|
||||
modelCopy.AuthTokenSeen = true
|
||||
@ -268,18 +275,24 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken,
|
||||
clientIP net.IP, userAgent string) (bool, error) {
|
||||
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
if token == nil {
|
||||
return false, nil
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
model, err := userAuthTokenFromUserToken(token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
now := getTime()
|
||||
|
||||
type rotationResult struct {
|
||||
rotated bool
|
||||
newToken *auth.UserToken
|
||||
}
|
||||
|
||||
rotResult, err, _ := s.singleflight.Do(fmt.Sprint(model.Id), func() (interface{}, error) {
|
||||
var needsRotation bool
|
||||
rotatedAt := time.Unix(model.RotatedAt, 0)
|
||||
if model.AuthTokenSeen {
|
||||
@ -289,7 +302,7 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.U
|
||||
}
|
||||
|
||||
if !needsRotation {
|
||||
return false, nil
|
||||
return &rotationResult{rotated: false}, nil
|
||||
}
|
||||
|
||||
ctxLogger := s.log.FromContext(ctx)
|
||||
@ -301,7 +314,7 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.U
|
||||
}
|
||||
newToken, err := util.RandomHex(16)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, err
|
||||
}
|
||||
hashedToken := hashToken(newToken)
|
||||
|
||||
@ -332,19 +345,32 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.U
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
|
||||
if affected > 0 {
|
||||
ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId)
|
||||
model.UnhashedToken = newToken
|
||||
if err := model.toUserToken(token); err != nil {
|
||||
return false, err
|
||||
var result auth.UserToken
|
||||
if err := model.toUserToken(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return true, nil
|
||||
return &rotationResult{
|
||||
rotated: true,
|
||||
newToken: &result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return &rotationResult{rotated: false}, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
result := rotResult.(*rotationResult)
|
||||
|
||||
return result.rotated, result.newToken, nil
|
||||
}
|
||||
|
||||
func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error {
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
|
||||
getTime = func() time.Time { return now.Add(time.Hour) }
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.11"), "some user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -262,7 +263,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
prevToken := userToken.AuthToken
|
||||
unhashedPrev := userToken.UnhashedToken
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.False(t, rotated)
|
||||
@ -280,12 +281,12 @@ func TestUserAuthToken(t *testing.T) {
|
||||
|
||||
getTime = func() time.Time { return now.Add(time.Hour) }
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), &tok,
|
||||
rotated, newToken, err := ctx.tokenService.TryRotateToken(context.Background(), &tok,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
|
||||
unhashedToken := tok.UnhashedToken
|
||||
unhashedToken := newToken.UnhashedToken
|
||||
|
||||
model, err = ctx.getAuthTokenByID(tok.Id)
|
||||
require.Nil(t, err)
|
||||
@ -326,7 +327,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
require.NotNil(t, lookedUpModel)
|
||||
require.False(t, lookedUpModel.AuthTokenSeen)
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("192.168.10.12"), "a new user agent")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -351,7 +352,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
getTime = func() time.Time { return now.Add(10 * time.Minute) }
|
||||
|
||||
prevToken := userToken.UnhashedToken
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -407,7 +408,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(10 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -429,7 +430,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(20 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -456,7 +457,7 @@ func TestUserAuthToken(t *testing.T) {
|
||||
return now.Add(2 * time.Minute)
|
||||
}
|
||||
|
||||
rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken,
|
||||
net.ParseIP("1.1.1.1"), "firefox")
|
||||
require.Nil(t, err)
|
||||
require.True(t, rotated)
|
||||
@ -553,6 +554,7 @@ func createTestContext(t *testing.T) *testContext {
|
||||
sqlStore: sqlstore,
|
||||
cfg: cfg,
|
||||
log: log.New("test-logger"),
|
||||
singleflight: new(singleflight.Group),
|
||||
}
|
||||
|
||||
return &testContext{
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
|
||||
type FakeUserAuthTokenService struct {
|
||||
CreateTokenProvider func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error)
|
||||
TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error)
|
||||
TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error)
|
||||
LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error)
|
||||
RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error
|
||||
RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error
|
||||
@ -34,8 +34,8 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
|
||||
UnhashedToken: "",
|
||||
}, nil
|
||||
},
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) {
|
||||
return false, nil
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
return false, nil, nil
|
||||
},
|
||||
LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
@ -79,7 +79,7 @@ func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToke
|
||||
}
|
||||
|
||||
func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP,
|
||||
userAgent string) (bool, error) {
|
||||
userAgent string) (bool, *auth.UserToken, error) {
|
||||
return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent)
|
||||
}
|
||||
|
||||
|
@ -107,13 +107,14 @@ func (s *Session) RefreshTokenHook(ctx context.Context, identity *authn.Identity
|
||||
s.log.Debug("failed to get client IP address", "addr", addr, "err", err)
|
||||
ip = nil
|
||||
}
|
||||
rotated, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent)
|
||||
rotated, newToken, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent)
|
||||
if err != nil {
|
||||
s.log.Error("failed to rotate token", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if rotated {
|
||||
identity.SessionToken = newToken
|
||||
s.log.Debug("rotated session token", "user", identity.ID)
|
||||
|
||||
maxAge := int(s.loginMaxLifetime.Seconds())
|
||||
|
@ -143,9 +143,9 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) {
|
||||
|
||||
func TestSession_RefreshHook(t *testing.T) {
|
||||
s := ProvideSession(&authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) {
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
token.UnhashedToken = "new-token"
|
||||
return true, nil
|
||||
return true, token, nil
|
||||
},
|
||||
}, &usertest.FakeUserService{}, "grafana-session", 20*time.Second)
|
||||
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/components/apikeygen"
|
||||
apikeygenprefix "github.com/grafana/grafana/pkg/components/apikeygenprefixed"
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
@ -70,6 +72,7 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer
|
||||
oauthTokenService: oauthTokenService,
|
||||
features: features,
|
||||
authnService: authnService,
|
||||
singleflight: new(singleflight.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,6 +94,7 @@ type ContextHandler struct {
|
||||
oauthTokenService oauthtoken.OAuthTokenService
|
||||
features *featuremgmt.FeatureManager
|
||||
authnService authn.Service
|
||||
singleflight *singleflight.Group
|
||||
// GetTime returns the current time.
|
||||
// Stubbable by tests.
|
||||
GetTime func() time.Time
|
||||
@ -568,15 +572,15 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *contextmodel.ReqCont
|
||||
ip = nil
|
||||
}
|
||||
|
||||
// FIXME (jguer): rotation should return a new token instead of modifying the existing one.
|
||||
rotated, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
|
||||
rotated, newToken, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
|
||||
if err != nil {
|
||||
reqContext.Logger.Error("Failed to rotate token", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if rotated {
|
||||
cookies.WriteSessionCookie(reqContext, h.Cfg, reqContext.UserToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
|
||||
reqContext.UserToken = newToken
|
||||
cookies.WriteSessionCookie(reqContext, h.Cfg, newToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -24,9 +24,9 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
|
||||
tryRotateCallCount := 0
|
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
|
||||
userAgent string) (bool, error) {
|
||||
userAgent string) (bool, *auth.UserToken, error) {
|
||||
tryRotateCallCount++
|
||||
return false, nil
|
||||
return false, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
@ -46,11 +46,11 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
|
||||
ctxHdlr := getContextHandler(t)
|
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
|
||||
userAgent string) (bool, error) {
|
||||
userAgent string) (bool, *auth.UserToken, error) {
|
||||
newToken, err := util.RandomHex(16)
|
||||
require.NoError(t, err)
|
||||
token.AuthToken = newToken
|
||||
return true, nil
|
||||
return true, token, nil
|
||||
},
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user